In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy


In [2]:
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.25, 0.25, 0.25])

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
}

In [3]:
data_dir = "E:/casper/raw_data_training"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=0)
              for x in ['train', 'val']}

In [4]:
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(class_names)
print(device)

['IgA', 'MN']
cuda:0


In [5]:
def imshow(inp, title):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.title(title)
    plt.show()


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

# imshow(out, title=[class_names[x] for x in classes])

In [6]:
print(classes)

tensor([0, 0, 1, 0])


In [7]:
def train_model_ori(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                # print(labels.data)
                
                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [8]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    best_recallmn = 0.0 # ***
    best_recallig = 0.0 # ***
    
    best_precmn = 0.0 # ***
    best_precig = 0.0 # ***
    
    best_f1mn = 0.0 # ***
    best_f1ig = 0.0 # ***

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            
            tp_positive = 0
            fp_positive = 0
            tn_negative = 0
            fn_negative = 0
            
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)


                    # backward + optimize only if in training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                # MN -> positive
                # iga -> negative
                tp_positive += torch.sum((preds == 1) & (labels.data == 1))
                fp_positive += torch.sum((preds == 1) & (labels.data == 0))
                tn_negative += torch.sum((preds == 0) & (labels.data == 0))
                fn_negative += torch.sum((preds == 0) & (labels.data == 1))
                    
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            epoch_recallmn = tp_positive.double() / (tp_positive + fn_negative) # ***
            epoch_recallig = tn_negative.double() / (tn_negative + fp_positive) # ***
            
            epoch_precmn = tp_positive.double() / (tp_positive + fp_positive) # ***
            epoch_precig = tn_negative.double() / (tn_negative + fn_negative) # ***
            
            epoch_f1mn = (2 * epoch_recallmn * epoch_precmn) / (epoch_recallmn + epoch_precmn) # ***
            epoch_f1ig = (2 * epoch_recallig * epoch_precig) / (epoch_recallig + epoch_precig) # ***
            
            # print('{} Loss: {:.4f} Acc: {:.4f} Recall_MGN: {:.4f} Recall_IGAN: {:.4f} Precision_MGN: {:.4f} Precision_IGAN: {:.4f} F1_MGN: {:.4f} F1_IGAN: {:.4f}'.format(
            #     phase, epoch_loss, epoch_acc, epoch_recallmn, epoch_recallig, epoch_precmn, epoch_precig, epoch_f1mn, epoch_f1ig)) # ***
            print('{} Loss: {:.4f} Accuracy: {:.4f} \n     Recall Precision F1_score\n MGN: {:.4f} {:.4f} {:.4f}\n IGAN:{:.4f} {:.4f} {:.4f}\n'.format(
                    phase, epoch_loss, epoch_acc, epoch_recallmn, epoch_precmn, epoch_f1mn, epoch_recallig, epoch_precig, epoch_f1ig)) # ***

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                
            if phase == 'val' and epoch_recallmn > best_recallmn: # ***
                best_recallmn = epoch_recallmn
            if phase == 'val' and epoch_recallig > best_recallig: # ***
                best_recallig = epoch_recallig  
                
            if phase == 'val' and epoch_precmn > best_precmn: # ***
                best_precmn = epoch_precmn
            if phase == 'val' and epoch_precig > best_precig: # ***
                best_precig = epoch_precig    
                
            if phase == 'val' and epoch_f1mn > best_f1mn: # ***
                best_f1mn = epoch_f1mn
            if phase == 'val' and epoch_f1ig > best_f1ig: # ***
                best_f1ig = epoch_f1ig   
                
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f} \n     Recall Precision F1_score\n MGN: {:.4f} {:.4f} {:.4f}\n IGAN:{:.4f} {:.4f} {:.4f}\n'.format(
                best_acc, best_recallmn, best_precmn, best_f1mn, best_recallig, best_precig, best_f1ig)) # ***

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [9]:
def train_model3(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Variables to keep track of true positives, false positives, true negatives, false negatives
    tp_positive = 0
    fp_positive = 0
    tn_negative = 0
    fn_negative = 0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Calculate precision for positive label and negative label
                    if phase == 'val':
                        tp_positive += torch.sum((preds == 1) & (labels.data == 1))
                        fp_positive += torch.sum((preds == 1) & (labels.data == 0))
                        tn_negative += torch.sum((preds == 0) & (labels.data == 0))
                        fn_negative += torch.sum((preds == 0) & (labels.data == 1))

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Calculate precision for positive label and negative label
    precision_positive = tp_positive.double() / (tp_positive + fp_positive)
    precision_negative = tn_negative.double() / (tn_negative + fn_negative)
    print('Precision (Positive Label): {:.4f}'.format(precision_positive))
    print('Precision (Negative Label): {:.4f}'.format(precision_negative))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [10]:
def train_mod(true_None, num_epoch):
    model = models.resnet18(weights=true_None)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    model = train_model(model, criterion, optimizer, step_lr_scheduler, num_epochs=num_epoch)    
    return model

In [11]:
model = train_mod(None, 2)
torch.save(model.state_dict(), "res18_2n.pt")

Epoch 0/1
----------
train Loss: 0.6144 Accuracy: 0.6535 
     Recall Precision F1_score
 MGN: 0.3290 0.5534 0.4127
 IGAN:0.8441 0.6817 0.7542

val Loss: 0.7011 Accuracy: 0.5279 
     Recall Precision F1_score
 MGN: 0.2688 0.7562 0.3967
 IGAN:0.8816 0.4690 0.6123


Epoch 1/1
----------
train Loss: 0.5820 Accuracy: 0.6793 
     Recall Precision F1_score
 MGN: 0.4518 0.5865 0.5104
 IGAN:0.8129 0.7163 0.7615

val Loss: 0.7207 Accuracy: 0.5627 
     Recall Precision F1_score
 MGN: 0.3128 0.8164 0.4523
 IGAN:0.9039 0.4907 0.6361


Training complete in 3m 53s
Best val Acc: 0.5627 
     Recall Precision F1_score
 MGN: 0.3128 0.8164 0.4523
 IGAN:0.9039 0.4907 0.6361



In [12]:
model = train_mod(None, 5)
torch.save(model.state_dict(), "res18_5n.pt")

Epoch 0/4
----------
train Loss: 0.6209 Accuracy: 0.6544 
     Recall Precision F1_score
 MGN: 0.2996 0.5620 0.3908
 IGAN:0.8629 0.6772 0.7588

val Loss: 0.7132 Accuracy: 0.5540 
     Recall Precision F1_score
 MGN: 0.3794 0.7139 0.4955
 IGAN:0.7925 0.4833 0.6004


Epoch 1/4
----------
train Loss: 0.5874 Accuracy: 0.6815 
     Recall Precision F1_score
 MGN: 0.4361 0.5951 0.5034
 IGAN:0.8257 0.7137 0.7656

val Loss: 0.6802 Accuracy: 0.6041 
     Recall Precision F1_score
 MGN: 0.8543 0.6126 0.7135
 IGAN:0.2624 0.5688 0.3592


Epoch 2/4
----------
train Loss: 0.5649 Accuracy: 0.7022 
     Recall Precision F1_score
 MGN: 0.4947 0.6228 0.5514
 IGAN:0.8240 0.7352 0.7771

val Loss: 0.7038 Accuracy: 0.6222 
     Recall Precision F1_score
 MGN: 0.5327 0.7400 0.6194
 IGAN:0.7444 0.5385 0.6249


Epoch 3/4
----------
train Loss: 0.5584 Accuracy: 0.7046 
     Recall Precision F1_score
 MGN: 0.4996 0.6264 0.5559
 IGAN:0.8250 0.7373 0.7787

val Loss: 0.6115 Accuracy: 0.6882 
     Recall Precision F

In [13]:
model = train_mod(None, 20)
torch.save(model.state_dict(), "res18_20n.pt")

Epoch 0/19
----------
train Loss: 0.6115 Accuracy: 0.6578 
     Recall Precision F1_score
 MGN: 0.3548 0.5592 0.4342
 IGAN:0.8357 0.6880 0.7547

val Loss: 0.6540 Accuracy: 0.6302 
     Recall Precision F1_score
 MGN: 0.6834 0.6783 0.6809
 IGAN:0.5575 0.5633 0.5603


Epoch 1/19
----------
train Loss: 0.5850 Accuracy: 0.6773 
     Recall Precision F1_score
 MGN: 0.4451 0.5840 0.5052
 IGAN:0.8137 0.7140 0.7606

val Loss: 0.6459 Accuracy: 0.6338 
     Recall Precision F1_score
 MGN: 0.6244 0.7070 0.6631
 IGAN:0.6467 0.5577 0.5989


Epoch 2/19
----------
train Loss: 0.5738 Accuracy: 0.6933 
     Recall Precision F1_score
 MGN: 0.4740 0.6101 0.5335
 IGAN:0.8221 0.7269 0.7715

val Loss: 0.6272 Accuracy: 0.6476 
     Recall Precision F1_score
 MGN: 0.8982 0.6384 0.7463
 IGAN:0.3053 0.6873 0.4228


Epoch 3/19
----------
train Loss: 0.5673 Accuracy: 0.6956 
     Recall Precision F1_score
 MGN: 0.4885 0.6108 0.5429
 IGAN:0.8172 0.7312 0.7718

val Loss: 0.6803 Accuracy: 0.6120 
     Recall Precisi

In [14]:
model = train_mod(True, 2)
torch.save(model.state_dict(), "res18_2t.pt")

Epoch 0/1
----------




train Loss: 0.5522 Accuracy: 0.7082 
     Recall Precision F1_score
 MGN: 0.5200 0.6276 0.5688
 IGAN:0.8187 0.7439 0.7795

val Loss: 0.5011 Accuracy: 0.7542 
     Recall Precision F1_score
 MGN: 0.8505 0.7547 0.7998
 IGAN:0.6226 0.7531 0.6817


Epoch 1/1
----------
train Loss: 0.5005 Accuracy: 0.7552 
     Recall Precision F1_score
 MGN: 0.6091 0.6923 0.6481
 IGAN:0.8410 0.7856 0.8123

val Loss: 0.4555 Accuracy: 0.7941 
     Recall Precision F1_score
 MGN: 0.7814 0.8497 0.8141
 IGAN:0.8113 0.7311 0.7691


Training complete in 3m 52s
Best val Acc: 0.7941 
     Recall Precision F1_score
 MGN: 0.8505 0.8497 0.8141
 IGAN:0.8113 0.7531 0.7691



In [15]:
model = train_mod(True, 5)
torch.save(model.state_dict(), "res18_5t.pt")

Epoch 0/4
----------
train Loss: 0.5604 Accuracy: 0.7036 
     Recall Precision F1_score
 MGN: 0.5053 0.6225 0.5578
 IGAN:0.8200 0.7384 0.7770

val Loss: 0.5208 Accuracy: 0.7302 
     Recall Precision F1_score
 MGN: 0.8467 0.7294 0.7837
 IGAN:0.5712 0.7319 0.6416


Epoch 1/4
----------
train Loss: 0.4965 Accuracy: 0.7577 
     Recall Precision F1_score
 MGN: 0.6160 0.6945 0.6529
 IGAN:0.8409 0.7885 0.8139

val Loss: 0.5018 Accuracy: 0.7759 
     Recall Precision F1_score
 MGN: 0.7286 0.8618 0.7897
 IGAN:0.8405 0.6941 0.7603


Epoch 2/4
----------
train Loss: 0.4626 Accuracy: 0.7801 
     Recall Precision F1_score
 MGN: 0.6620 0.7210 0.6902
 IGAN:0.8495 0.8106 0.8296

val Loss: 0.4681 Accuracy: 0.7955 
     Recall Precision F1_score
 MGN: 0.7852 0.8492 0.8159
 IGAN:0.8096 0.7341 0.7700


Epoch 3/4
----------
train Loss: 0.4417 Accuracy: 0.7942 
     Recall Precision F1_score
 MGN: 0.6852 0.7396 0.7114
 IGAN:0.8583 0.8228 0.8401

val Loss: 0.4344 Accuracy: 0.8013 
     Recall Precision F

In [16]:
model = train_mod(True, 20)
torch.save(model.state_dict(), "res18_20t.pt")

Epoch 0/19
----------
train Loss: 0.5542 Accuracy: 0.7165 
     Recall Precision F1_score
 MGN: 0.5300 0.6414 0.5804
 IGAN:0.8260 0.7495 0.7859

val Loss: 0.4957 Accuracy: 0.7679 
     Recall Precision F1_score
 MGN: 0.7337 0.8439 0.7849
 IGAN:0.8148 0.6914 0.7480


Epoch 1/19
----------
train Loss: 0.4923 Accuracy: 0.7631 
     Recall Precision F1_score
 MGN: 0.6269 0.7012 0.6620
 IGAN:0.8431 0.7937 0.8177

val Loss: 0.5286 Accuracy: 0.7650 
     Recall Precision F1_score
 MGN: 0.6658 0.9014 0.7659
 IGAN:0.9005 0.6637 0.7642


Epoch 2/19
----------
train Loss: 0.4597 Accuracy: 0.7839 
     Recall Precision F1_score
 MGN: 0.6630 0.7286 0.6942
 IGAN:0.8549 0.8120 0.8329

val Loss: 0.4387 Accuracy: 0.7948 
     Recall Precision F1_score
 MGN: 0.7676 0.8618 0.8120
 IGAN:0.8319 0.7239 0.7741


Epoch 3/19
----------
train Loss: 0.4357 Accuracy: 0.8014 
     Recall Precision F1_score
 MGN: 0.6942 0.7504 0.7213
 IGAN:0.8644 0.8280 0.8458

val Loss: 0.4357 Accuracy: 0.8013 
     Recall Precisi

In [72]:
# #### Finetuning the convnet ####
# # Load a pretrained model and reset final fully connected layer.

# model = models.resnet18(weights=True)
# # print(model)
# num_ftrs = model.fc.in_features
# # Here the size of each output sample is set to 2.
# # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
# model.fc = nn.Linear(num_ftrs, 2)
# model = model.to(device)
# criterion = nn.CrossEntropyLoss()
# # Observe that all parameters are being optimized
# optimizer = optim.SGD(model.parameters(), lr=0.001)
# # StepLR Decays the learning rate of each parameter group by gamma every step_size epochs
# # Decay LR by a factor of 0.1 every 7 epochs
# # Learning rate scheduling should be applied after optimizer’s update
# # e.g., you should write your code this way:
# # for epoch in range(100):
# #     train(...)
# #     validate(...)
# #     scheduler.step()

# step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# model = train_model(model, criterion, optimizer, step_lr_scheduler, num_epochs=2)
# # torch.save(model.state_dict(), "res18_1.pt")

In [73]:
# #### ConvNet as fixed feature extractor ####
# # Here, we need to freeze all the network except the final layer.
# # We need to set requires_grad == False to freeze the parameters so that the gradients are not computed in backward()
# model_conv = torchvision.models.resnet18(weights=True)
# for param in model_conv.parameters():
#     param.requires_grad = False

# # Parameters of newly constructed modules have requires_grad=True by default
# num_ftrs = model_conv.fc.in_features
# model_conv.fc = nn.Linear(num_ftrs, 2)

# model_conv = model_conv.to(device)

# criterion = nn.CrossEntropyLoss()

# # Observe that only parameters of final layer are being optimized as
# # opposed to before.
# optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# # Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

# model_conv = train_model(model_conv, criterion, optimizer_conv,
#                          exp_lr_scheduler, num_epochs=2)

In [74]:
# model_v2 = models.resnet18(weights=True)
# num_ftrs = model_v2.fc.in_features
# model_v2.fc = nn.Linear(num_ftrs, 2)
# model_v2 = model_v2.to(device)
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model_v2.parameters(), lr=0.001)
# step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# model_v2 = train_model(model_v2, criterion, optimizer, step_lr_scheduler, num_epochs=20)


In [75]:
# model_conv_v2 = torchvision.models.resnet18(weights=True)
# for param in model_conv.parameters():
#     param.requires_grad = False

# # Parameters of newly constructed modules have requires_grad=True by default
# num_ftrs = model_conv_v2.fc.in_features
# model_conv_v2.fc = nn.Linear(num_ftrs, 2)

# model_conv_v2 = model_conv_v2.to(device)

# criterion = nn.CrossEntropyLoss()

# # Observe that only parameters of final layer are being optimized as
# # opposed to before.
# optimizer_conv = optim.SGD(model_conv_v2.fc.parameters(), lr=0.001, momentum=0.9)

# # Decay LR by a factor of 0.1 every 7 epochs
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

# model_conv_v2 = train_model(model_conv_v2, criterion, optimizer_conv,
#                          exp_lr_scheduler, num_epochs=20)

In [76]:
# model_v3 = models.resnet18(weights=False)
# num_ftrs = model_v3.fc.in_features
# model_v3.fc = nn.Linear(num_ftrs, 2)
# model_v3 = model_v3.to(device)
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model_v3.parameters(), lr=0.001)
# step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
# model_v3 = train_model(model_v3, criterion, optimizer, step_lr_scheduler, num_epochs=2)
