In [1]:
import torchvision.models as models
import os
import sys
from torch.utils.data import Dataset,DataLoader
from easydict import EasyDict as edict
import json
import torchvision
import torch.nn as nn
import torch
import torch.nn.functional as F
import copy
from tqdm import tqdm
from torchvision import datasets, transforms
from sklearn.metrics import roc_auc_score
import torch.utils.data as data_utils

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
from __future__ import print_function, division

import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn import metrics

In [4]:
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [5]:
# from data.dataset import ImageDataset

In [6]:
device = torch.device("cuda")

In [7]:
torch.cuda.device_count()

1

In [8]:
batch_size = 16
random_seed = 8

In [9]:
# mimic = 0
m = 10000
n = 20000
t = 1024

In [10]:
class Model(nn.Module):
    def __init__(self,t):
        super().__init__()
        self.res = models.resnet152(pretrained=True)
        num_ftrs = self.res.fc.in_features
        self.res.fc = nn.Linear(num_ftrs, t)

        self.linear = nn.Linear(t, 1)
        
        self.norm = nn.LayerNorm(t)
        
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        x = self.leaky_relu(self.res(x))
        x = self.leaky_relu(self.norm(x))
        x = self.leaky_relu(self.linear(x))
        return x

In [11]:
# vgg19 = models.vgg19(pretrained=True)
# mod = list(vgg19.classifier.children())
# mod.pop()
# mod.append(torch.nn.Linear(4096, 1024))
# mod.append(torch.nn.Linear(1024, 256))
# mod.append(torch.nn.Linear(256, 1))
# new_classifier = torch.nn.Sequential(*mod)
# vgg19.classifier = new_classifier
# model = vgg19.cuda()

In [12]:
model = Model(t)

In [13]:
model = model.to(device)

In [14]:
transform = {
    'train':
    transforms.Compose(
        [
#             transforms.ToPILImage(),
            transforms.Resize((256,256)),
            transforms.RandomResizedCrop((224),scale=(0.9,1)),
            transforms.RandomHorizontalFlip(),
#             transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]),
    'val':
    transforms.Compose(
    [
        transforms.Resize((256,256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
}

In [15]:
real_mimic_train_path = '/shared/rsaas/nschiou2/CXR/data/train/mimic'
real_chexpert_train_path = '/shared/rsaas/nschiou2/CXR/data/train/chexpert'
real_mimic_test_path = '/shared/rsaas/nschiou2/CXR/data/test/mimic'
real_chexpert_test_path = '/shared/rsaas/nschiou2/CXR/data/test/chexpert'

In [16]:
# syn = datasets.ImageFolder('/home/diwenxu2/data_v3/train/syn', transform=transform['train'])
real_mimic = datasets.ImageFolder(real_mimic_train_path, transform=transform['train'])
chexpert = datasets.ImageFolder(real_chexpert_train_path, transform=transform['train'])
# syn_chexpert = datasets.ImageFolder('/home/diwenxu2/syn_chex', transform=transform['train'])
mimic_test = datasets.ImageFolder(real_mimic_test_path, transform=transform['val'])
# comb = datasets.ImageFolder('/shared/rsaas/diwenxu2/data_v3/train/comb', transform=transform['train'])

In [17]:
index_c = np.arange(len(real_mimic))
np.random.seed(random_seed)
np.random.shuffle(index_c)
# index_s = np.arange(len(syn))
# np.random.seed(random_seed)
# np.random.shuffle(index_s)

In [18]:
# if mimic:
#     m = n-20
    
# else:
#     m = n
# m = n

In [19]:
# selected = data_utils.Subset(chexpert, index[:m])
# selected = data_utils.Subset(syn_chexpert, index[:m])
# selected = data_utils.Subset(syn, index[:m])
# selected = data_utils.Subset(chexpert, index[:m])
# data_syn = data_utils.Subset(syn, index_s[:m])
data_chex = data_utils.Subset(chexpert,index_c[:n])

In [20]:
# if mimic:
#     dataset = torch.utils.data.ConcatDataset([chexpert,selected])
# else:
#     dataset = selected
# dataset = torch.utils.data.ConcatDataset([data_syn,data_chex])
dataset = data_utils.Subset(real_mimic, index_c[:n])

In [21]:
total = len(dataset)
total

20000

In [22]:
tra = int((total)/5*4)
val = int(total-tra)

In [23]:
tra_set, va_set = torch.utils.data.random_split(dataset, [tra,val],generator=torch.Generator().manual_seed(random_seed))
# tra_set = dataset
# va_set = mimic_test

In [24]:
dataset_sizes = {'train':len(tra_set),'val':len(va_set)}

In [25]:
dataset_sizes

{'train': 16000, 'val': 4000}

In [26]:
train_set = torch.utils.data.DataLoader(tra_set, batch_size=batch_size, shuffle=True, drop_last=True)
val_set = torch.utils.data.DataLoader(va_set, batch_size=batch_size, shuffle=True, drop_last=True)

In [27]:
len(train_set),len(val_set)

(1000, 250)

In [28]:
dataloaders = {'train':train_set,'val':val_set}

In [29]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            label_list = []
            output_list = []

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.unsqueeze(1).to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    labels = labels.type_as(outputs)
#                     _, preds = torch.max(outputs, 1)
                    preds = torch.sigmoid(outputs)>0.5
                    loss = criterion(torch.sigmoid(outputs), labels)
                    
                    for i in range(len(outputs)):
                        output_list.append(torch.sigmoid(outputs[i]).cpu().data.numpy().tolist())
                        label_list.append(labels[i].cpu().data.numpy().tolist())

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
#             fpr, tpr, thresholds = metrics.roc_curve(np.array(label_list), np.array(output_list), pos_label=2)
#             auc = metrics.auc(fpr, tpr)
            auc = roc_auc_score(np.array(label_list),np.array(output_list))

            print('{} Loss: {:.4f} Acc: {:.4f} AUC: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, auc))
#             print('{} Loss: {:.4f} Acc: {:.4f}'.format(
#                 phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [30]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCEWithLogitsLoss()
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [31]:
model_ft = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=10)

Epoch 0/9
----------
train Loss: 0.6947 Acc: 0.5001 AUC: 0.6387
val Loss: 0.6848 Acc: 0.4995 AUC: 0.7861

Epoch 1/9
----------
train Loss: 0.6454 Acc: 0.6597 AUC: 0.8097
val Loss: 0.6116 Acc: 0.7592 AUC: 0.8662

Epoch 2/9
----------
train Loss: 0.6061 Acc: 0.7726 AUC: 0.8679
val Loss: 0.5997 Acc: 0.7895 AUC: 0.8862

Epoch 3/9
----------
train Loss: 0.5975 Acc: 0.7949 AUC: 0.8800
val Loss: 0.5960 Acc: 0.7848 AUC: 0.8855

Epoch 4/9
----------
train Loss: 0.5919 Acc: 0.8036 AUC: 0.8874
val Loss: 0.5913 Acc: 0.8127 AUC: 0.8934

Epoch 5/9
----------
train Loss: 0.5892 Acc: 0.8135 AUC: 0.8909
val Loss: 0.5912 Acc: 0.8153 AUC: 0.8888

Epoch 6/9
----------
train Loss: 0.5869 Acc: 0.8131 AUC: 0.8943
val Loss: 0.5966 Acc: 0.7708 AUC: 0.8898

Epoch 7/9
----------
train Loss: 0.5846 Acc: 0.8083 AUC: 0.9002
val Loss: 0.5893 Acc: 0.8115 AUC: 0.8955

Epoch 8/9
----------
train Loss: 0.5807 Acc: 0.8267 AUC: 0.9045
val Loss: 0.5889 Acc: 0.8115 AUC: 0.8965

Epoch 9/9
----------
train Loss: 0.5798 Acc: 0

In [32]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCEWithLogitsLoss(reduction='sum')
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [33]:
model_ft = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=10)

Epoch 0/9
----------
train Loss: 20.5120 Acc: 0.6707 AUC: 0.7614
val Loss: 20.6878 Acc: 0.7265 AUC: 0.7888

Epoch 1/9
----------
train Loss: 20.6017 Acc: 0.6875 AUC: 0.7455
val Loss: 21.0078 Acc: 0.6208 AUC: 0.7156

Epoch 2/9
----------
train Loss: 20.5911 Acc: 0.6756 AUC: 0.7283
val Loss: 22.0233 Acc: 0.6668 AUC: 0.7060

Epoch 3/9
----------
train Loss: 20.4134 Acc: 0.7023 AUC: 0.7618
val Loss: 20.2730 Acc: 0.7402 AUC: 0.7825

Epoch 4/9
----------
train Loss: 20.3392 Acc: 0.7181 AUC: 0.7670
val Loss: 21.1782 Acc: 0.6302 AUC: 0.7310

Epoch 5/9
----------
train Loss: 20.1432 Acc: 0.7199 AUC: 0.7786
val Loss: 19.8673 Acc: 0.7412 AUC: 0.8109

Epoch 6/9
----------
train Loss: 20.1472 Acc: 0.7167 AUC: 0.7719
val Loss: 20.3524 Acc: 0.6757 AUC: 0.7900

Epoch 7/9
----------
train Loss: 20.1190 Acc: 0.6994 AUC: 0.7798
val Loss: 19.8252 Acc: 0.7362 AUC: 0.8122

Epoch 8/9
----------
train Loss: 19.7819 Acc: 0.7407 AUC: 0.8159
val Loss: 19.7045 Acc: 0.7418 AUC: 0.8223

Epoch 9/9
----------
train L

In [34]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCEWithLogitsLoss()
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [35]:
model_ft = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=300)

Epoch 0/299
----------
train Loss: 0.6120 Acc: 0.7540 AUC: 0.8321
val Loss: 0.6103 Acc: 0.7568 AUC: 0.8305

Epoch 1/299
----------
train Loss: 0.6125 Acc: 0.7575 AUC: 0.8303
val Loss: 0.6089 Acc: 0.7698 AUC: 0.8339

Epoch 2/299
----------
train Loss: 0.6104 Acc: 0.7656 AUC: 0.8342
val Loss: 0.6080 Acc: 0.7770 AUC: 0.8379

Epoch 3/299
----------
train Loss: 0.6105 Acc: 0.7660 AUC: 0.8351
val Loss: 0.6071 Acc: 0.7770 AUC: 0.8423

Epoch 4/299
----------
train Loss: 0.6101 Acc: 0.7666 AUC: 0.8366
val Loss: 0.6054 Acc: 0.7782 AUC: 0.8410

Epoch 5/299
----------
train Loss: 0.6083 Acc: 0.7706 AUC: 0.8399
val Loss: 0.6070 Acc: 0.7840 AUC: 0.8401

Epoch 6/299
----------
train Loss: 0.6081 Acc: 0.7720 AUC: 0.8408
val Loss: 0.6051 Acc: 0.7853 AUC: 0.8451

Epoch 7/299
----------
train Loss: 0.6067 Acc: 0.7752 AUC: 0.8432
val Loss: 0.6068 Acc: 0.7715 AUC: 0.8448

Epoch 8/299
----------
train Loss: 0.6080 Acc: 0.7722 AUC: 0.8427
val Loss: 0.6055 Acc: 0.7790 AUC: 0.8467

Epoch 9/299
----------
train

KeyboardInterrupt: 

In [36]:
batch_size = 16
random_seed = 8

In [37]:
tra_set, va_set = torch.utils.data.random_split(dataset, [tra,val],generator=torch.Generator().manual_seed(random_seed))

In [38]:
train_set = torch.utils.data.DataLoader(tra_set, batch_size=batch_size, shuffle=True, drop_last=True)
val_set = torch.utils.data.DataLoader(va_set, batch_size=batch_size, shuffle=True, drop_last=True)

In [30]:
# criterion = nn.CrossEntropyLoss()
criterion = nn.BCEWithLogitsLoss()
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [31]:
model_ft = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=100)

Epoch 0/99
----------
train Loss: 0.6652 Acc: 0.5969 AUC: 0.7425
val Loss: 0.6138 Acc: 0.7512 AUC: 0.8601

Epoch 1/99
----------
train Loss: 0.6093 Acc: 0.7642 AUC: 0.8564
val Loss: 0.6003 Acc: 0.7622 AUC: 0.8851

Epoch 2/99
----------
train Loss: 0.5999 Acc: 0.7816 AUC: 0.8693
val Loss: 0.5959 Acc: 0.7915 AUC: 0.8821

Epoch 3/99
----------
train Loss: 0.5950 Acc: 0.7969 AUC: 0.8774
val Loss: 0.5989 Acc: 0.8030 AUC: 0.8848

Epoch 4/99
----------
train Loss: 0.5916 Acc: 0.8074 AUC: 0.8864
val Loss: 0.5895 Acc: 0.8133 AUC: 0.8953

Epoch 5/99
----------
train Loss: 0.5904 Acc: 0.8086 AUC: 0.8857
val Loss: 0.5873 Acc: 0.8227 AUC: 0.8982

Epoch 6/99
----------
train Loss: 0.5875 Acc: 0.8134 AUC: 0.8884
val Loss: 0.5916 Acc: 0.7963 AUC: 0.8877

Epoch 7/99
----------
train Loss: 0.5823 Acc: 0.8186 AUC: 0.8985
val Loss: 0.5864 Acc: 0.8205 AUC: 0.8983

Epoch 8/99
----------
train Loss: 0.5806 Acc: 0.8306 AUC: 0.8998
val Loss: 0.5859 Acc: 0.8240 AUC: 0.8998

Epoch 9/99
----------
train Loss: 0.5

KeyboardInterrupt: 

In [168]:
if n>=1000:
    n = str(int(n/1000))+'k'
if m>=1000:
    m = str(int(m/1000))+'k'
weights_path = '/home/diwenxu2/dataset_gan/mimic_'+str(m)+'_seed'+str(random_seed)+'_v3.ckpt'
torch.save(model_ft.state_dict(), weights_path)
weights_path

'/home/diwenxu2/dataset_gan/mimic_10k_seed8_v3.ckpt'

In [49]:
weights_path

'/home/diwenxu2/dataset_gan/chexpert_30k_seed8_v3.ckpt'

In [None]:
model = models.resnet152(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model.load_state_dict(torch.load('covid_classifier.ckpt'))

In [None]:
model = model.cuda()

In [None]:
model.eval()

In [None]:
image_path = '/home/diwenxu2/xray_latent2im/model_Covid/stylegan_v2_xray_linear_lr0.0001_l2_w/images/w_1_seed12_Covid_max1.0_min0.0_sample3_0.64.png'

In [None]:
transform_ = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])

In [None]:
from PIL import Image

# image = Image.open(image_path)

In [None]:
path = '/home/diwenxu2/xray_latent2im/20k_images/'
files = os.listdir(path)

In [None]:
def pred(model,img_path):
    image = Image.open(img_path).convert('RGB')
    transformed_img = transform_(image)
    transformed_img = transformed_img.to('cuda')
    image = transformed_img.unsqueeze(0).cuda()
    output = model(image)
    _,preds = torch.max(output,1)
    return preds

In [None]:
pred(model,'/home/diwenxu2/xray_latent2im/model_Covid/stylegan_v2_xray_linear_lr0.0001_l2_w/images/org_img/w_1_seed12_Covid_max1.0_min0.0_sample2_0.01_org.jpg')

In [None]:
ratio = 0
for i in files[:1000]:
    img_path = path+i
    preds = pred(model,img_path)
    if preds == 1:
        ratio+=1
ratio/=1000
ratio

In [None]:
res = pred(model,path+files[0])

In [None]:
res==1

In [None]:
transformed_img = transform_(image)
transformed_img = transformed_img.to('cuda')
image = transformed_img.unsqueeze(0).cuda()
output = model(image)

In [None]:
img_path = '/home/diwenxu2/xray_latent2im/model_Covid/stylegan_v2_xray_linear_lr0.0001_l2_w/images/w_1_seed12_Covid_max0.0_min0.0_sample4_0.23.png'

In [None]:
image = Image.open(img_path).convert('RGB')
transformed_img = transform_(image)
transformed_img = transformed_img.to('cuda')
image = transformed_img.unsqueeze(0).cuda()
output = model(image)
print(output)

In [None]:
res = model(img)
res

In [None]:
_,preds = torch.max(res,1)
preds

In [None]:
reg_json = '/home/diwenxu2/Chexpert/config/example.json'
with open(reg_json) as f:
    cfg = edict(json.load(f))

In [None]:
dataloader_train = DataLoader(
        ImageDataset(cfg.train_csv, cfg, mode='train'),
        batch_size=batch_size, num_workers=12,
        drop_last=True, shuffle=True)

In [None]:
len(dataloader_train)

In [None]:
label_header = dataloader_train.dataset._label_header

In [None]:
label_header

In [None]:
dataloader_dev = DataLoader(
        ImageDataset(cfg.dev_csv, cfg, mode='dev'),
        batch_size=batch_size, num_workers=4,
        drop_last=False, shuffle=False)

In [None]:
len(dataloader_dev)

In [None]:
dev_header = dataloader_dev.dataset._label_header

In [None]:
dev_header

In [None]:
dataloaders = {'train':dataloader_train,'val':dataloader_dev}

In [None]:
dataiter = iter(train_set)

In [None]:
steps = len(dataloader_train)
steps

In [None]:
images, labels = next(dataiter)

In [None]:
imshow(images[2])

In [None]:
images = images.to(device)
model = model.to(device)

In [None]:
images.size()

In [None]:
outputs = model(images)

In [None]:
outputs.size()

In [None]:
labels.size()

In [None]:
labels[0]

In [None]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

In [None]:
label = torch.sigmoid(outputs[3].view(-1)).ge(0.5).float()

In [None]:
label

In [None]:
def get_loss(output, target, index, cfg, device):
    for num_class in cfg.num_classes:
        assert num_class == 1
    target = target[:, index].view(-1)
    pos_weight = torch.from_numpy(
        np.array(cfg.pos_weight,
                 dtype=np.float32)).to(device).type_as(target)
    if cfg.batch_weight:
        if target.sum() == 0:
            loss = torch.tensor(0., requires_grad=True).to(device)
        else:
            weight = (target.size()[0] - target.sum()) / target.sum()
            loss = F.binary_cross_entropy_with_logits(
                output[:,index].view(-1), target, pos_weight=weight)
    else:
        loss = F.binary_cross_entropy_with_logits(
            output[:,index].view(-1), target, pos_weight=pos_weight[index])

    return loss

In [None]:
def train_model(model, optimizer, scheduler, device, num_epochs):
    dataset_sizes = {x: len(dataloaders[x]) for x in ['train', 'val']}
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                loss = 0
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    for t in range(5):
                        loss_t = get_loss(outputs, labels, t, cfg, device)
                        loss += loss_t

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                preds = torch.sigmoid(outputs).ge(0.5).float()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)/(len(preds)*len(preds[0]))
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


In [None]:
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
model_ft = train_model(model, optimizer, exp_lr_scheduler,device,
                       num_epochs=15)

In [None]:
torch.save(model_ft.state_dict(), 'vgg19.ckpt')

In [None]:
vgg19 = models.vgg19(pretrained=False)
mod = list(vgg19.classifier.children())
mod.pop()
mod.append(torch.nn.Linear(4096, 2))
new_classifier = torch.nn.Sequential(*mod)
vgg19.classifier = new_classifier

In [None]:
# vgg19 = nn.DataParallel(vgg19)
vgg19.load_state_dict(torch.load('covid_classifier.ckpt'))
vgg19 = vgg19.cuda().features.eval()

In [None]:
model.train()
num_tasks = len(cfg.num_classes)
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader_train, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].cuda(), data[1].cuda()
        outputs = model(inputs)
            
        loss = 0
        for t in range(num_tasks):
            loss_t = get_loss(outputs, labels, t, cfg)
            loss += loss_t
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')