In [None]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

In [None]:
# Install gdown into the runtime via the shell.
# NOTE(review): gdown is not imported anywhere in the visible part of this notebook — confirm it is still needed.
!pip install gdown

In [None]:
def get_file_id_by_model(folder_name):
    """Look up the file id registered for a pretrained-model folder name.

    Args:
        folder_name: key such as ``'resnet18_100-epochs_stl10'``.

    Returns:
        The id string registered for ``folder_name``, or the literal string
        ``"Model not found."`` when the name is not in the table.
    """
    known_ids = {
        'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
        'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
        'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu',
    }
    if folder_name in known_ids:
        return known_ids[folder_name]
    # Unknown names yield a sentinel string rather than raising.
    return "Model not found."

In [None]:
# Pick the pretrained-model folder and resolve its file id via the lookup table.
# get_file_id_by_model returns the string "Model not found." for unknown names,
# so the printed value should be inspected before it is used.
folder_name = 'resnet50_50-epochs_stl10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

In [None]:
# !mkdir /content/images
# %cd /content/images
# !pip install -q kaggle
# ! cp /content/kaggle.json ~/.kaggle/
# !kaggle datasets download nih-chest-xrays/data
# !unzip -j data.zip -d .
# !rm data.zip
# %cd /content/

In [None]:
!pip install torchinfo wandb

!wandb login 606ef0ddb19fbf179952be1ae9823b40ec33b3b7
import wandb

In [None]:
# user ="vidura"
# project = "medicap-contrastive"
# run = "2x9w9qwt"

# best_model = wandb.restore('last_checkpoint.pth.tar', run_path="{}/{}/{}".format(user,project,run))

In [None]:
# Generate an identifier for this W&B run and print it so it can be noted down.
# NOTE(review): the name `id` shadows Python's builtin `id()` for the rest of
# the session; it is kept because the wandb.init cell reads it under this name.
id = wandb.util.generate_id()
print(id)

In [None]:
# Start the W&B run for the fine-tuning project using the id generated above.
# NOTE(review): resume=True presumably asks W&B to continue an existing run
# with this id when there is one — confirm against the wandb.init docs.
wandb.init(id=id, project="medicap-contrastive-finetune", entity="raveen_hansika", resume=True)

In [None]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn as nn
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image


In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using device:", device)

In [None]:
class DenseNet121(nn.Module):
    """DenseNet-121 with a sigmoid-activated classification head.

    Wraps torchvision's DenseNet-121 and replaces its classifier with a
    ``Linear -> Sigmoid`` stack, so the forward pass produces one value in
    (0, 1) per class.

    Args:
        pretrained: forwarded to ``torchvision.models.densenet121``.
        num_classes: output width of the replacement linear layer.
    """
    def __init__(self, pretrained=True, num_classes=14):
        super().__init__()
        backbone = torchvision.models.densenet121(pretrained=pretrained)

        # Keep the classifier a Sequential so the linear layer's parameters
        # are named ``classifier.0.weight`` / ``classifier.0.bias``.
        head = nn.Sequential(
            nn.Linear(backbone.classifier.in_features, num_classes),
            nn.Sigmoid(),
        )
        backbone.classifier = head
        self.densenet121 = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.densenet121(x)


# Human-readable names for the label columns.
# NOTE(review): this list has 15 entries (the last is "No findings") while
# DenseNet121 defaults to num_classes=14, and the list is not referenced in the
# visible part of this notebook — confirm which of the two matches the label files.
CLASS_NAMES = [ 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia',"No findings"]
class ContrastiveDataset(Dataset):
    """Image/label dataset driven by a whitespace-separated list file.

    Every non-empty line of the list file holds an image file name followed
    by integer label values; the name is joined onto ``data_dir`` and the
    integers become that sample's label vector.
    """
    def __init__(self, data_dir, split, transform=None):
        """
        Args:
            data_dir: path to image directory.
            split: path to the file containing image names with their
                corresponding labels, one sample per line.
            transform: optional transform to be applied on a sample.
        """
        image_names = []
        labels = []
        with open(split, "r") as f:
            for line in f:
                items = line.split()
                if not items:
                    # Blank or whitespace-only lines (e.g. a trailing newline
                    # at the end of the file) carry no sample; skipping them
                    # avoids an IndexError on items[0].
                    continue
                image_names.append(os.path.join(data_dir, items[0]))
                labels.append([int(i) for i in items[1:]])
        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item

        Returns:
            image and its labels (labels as a float tensor)
        """
        image_name = self.image_names[index]
        # Image.open is lazy; convert() forces the pixel data to load and
        # returns an independent RGB image, so the file handle can be closed
        # right away instead of waiting for garbage collection.
        with Image.open(image_name) as source:
            image = source.convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)

        return image, torch.FloatTensor(label)

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.image_names)


In [None]:
from torchvision.transforms import transforms


def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
    """Build the STL10 train and test loaders rooted at ``./data``.

    Args:
        download: forwarded to ``datasets.STL10`` for both splits.
        shuffle: forwarded to both loaders.
        batch_size: batch size of the train loader; the test loader uses
            twice this value.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_loader = DataLoader(
        datasets.STL10('./data', split='train', download=download,
                       transform=transforms.ToTensor()),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        drop_last=False,
    )
    test_loader = DataLoader(
        datasets.STL10('./data', split='test', download=download,
                       transform=transforms.ToTensor()),
        batch_size=2 * batch_size,
        shuffle=shuffle,
        num_workers=10,
        drop_last=False,
    )
    return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
    """Build the CIFAR10 train and test loaders rooted at ``./data``.

    Args:
        download: forwarded to ``datasets.CIFAR10`` for both splits.
        shuffle: forwarded to both loaders.
        batch_size: batch size of the train loader; the test loader uses
            twice this value.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_loader = DataLoader(
        datasets.CIFAR10('./data', train=True, download=download,
                         transform=transforms.ToTensor()),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        drop_last=False,
    )
    test_loader = DataLoader(
        datasets.CIFAR10('./data', train=False, download=download,
                         transform=transforms.ToTensor()),
        batch_size=2 * batch_size,
        shuffle=shuffle,
        num_workers=10,
        drop_last=False,
    )
    return train_loader, test_loader

def get_medicap_contrastive_transform(size):
    """Return the preprocessing pipeline: ``Resize(size)`` then ``ToTensor()``.

    Args:
        size: value handed to ``transforms.Resize``.
    """
    steps = [transforms.Resize(size), transforms.ToTensor()]
    return transforms.Compose(steps)
def get_chestxray_data_loaders(root_folder, shuffle=False, batch_size=256,
                               train_split="/kaggle/input/chexnet-file-list/train_list.txt",
                               val_split="/kaggle/input/chexnet-file-list/val_list.txt",
                               image_size=256):
    """Build the chest X-ray train and validation loaders.

    Args:
        root_folder: directory the image names in the split files are
            joined onto.
        shuffle: forwarded to both loaders.
        batch_size: batch size of the train loader; the validation loader
            uses twice this value.
        train_split: path of the list file for the training set. Defaults to
            the Kaggle location this notebook was written against.
        val_split: path of the list file for the validation set. Defaults to
            the Kaggle location this notebook was written against.
        image_size: size handed to ``get_medicap_contrastive_transform``.

    Returns:
        ``(train_loader, test_loader)``; the second loader reads ``val_split``.
    """
    train_dataset = ContrastiveDataset(
        root_folder, split=train_split,
        transform=get_medicap_contrastive_transform(image_size))

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              num_workers=0, drop_last=False, shuffle=shuffle)

    test_dataset = ContrastiveDataset(
        root_folder, split=val_split,
        transform=get_medicap_contrastive_transform(image_size))

    test_loader = DataLoader(test_dataset, batch_size=2 * batch_size,
                             num_workers=10, drop_last=False, shuffle=shuffle)
    return train_loader, test_loader

In [None]:
class Struct:
    """Tiny attribute bag: every keyword argument becomes an instance attribute."""
    def __init__(self, **entries):
        # Write straight into the instance namespace so Struct(a=1).a == 1.
        vars(self).update(entries)
        
# Run configuration, exposed as attributes (config.arch, config.dataset_name, ...).
config = Struct(
    arch="chexnet",
    dataset_name="chexnet",
    root_folder="/kaggle/input/data/",
    result_dir="./",
)

# Lowest validation loss seen so far; +inf means any finite loss counts as an improvement.
best_valid_loss = np.inf

In [None]:
# Instantiate the network named by config.arch and place it on `device`.
if config.arch == 'resnet18':
    model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
    model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)
elif config.arch == "chexnet":
    # Moved to `device` here as well so all three branches behave alike.
    model = DenseNet121(pretrained=False, num_classes=14).to(device)
else:
    # Fail here, with a clear message, instead of with a NameError on `model`
    # in a later cell.
    raise ValueError(
        "Unsupported config.arch {!r}; expected 'resnet18', 'resnet50' or 'chexnet'".format(config.arch))

In [None]:
# Load the saved checkpoint onto the active device and take its weight dict.
checkpoint = torch.load('/kaggle/input/chexnet-simclr-model/best_checkpoint.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

# Rewrite the weight dict in place. The key list is snapshotted first so that
# entries added during the loop are never revisited.
#   * 'backbone.<name>' entries are re-registered as '<name>', except those
#     under 'backbone.fc' or 'backbone.densenet121.classifier', which are dropped.
#   * every original key is removed afterwards, so keys without the
#     'backbone.' prefix are dropped as well.
for k in list(state_dict.keys()):
    if k.startswith('backbone.') and not k.startswith(('backbone.fc', 'backbone.densenet121.classifier')):
        # remove prefix
        state_dict[k[len("backbone."):]] = state_dict[k]
    del state_dict[k]

In [None]:
# Show the epoch number stored in the loaded checkpoint.
print('resume model from', checkpoint['epoch'])

In [None]:
# strict=False lets the load succeed although the classifier entries were
# removed from state_dict in the remapping cell.
log = model.load_state_dict(state_dict, strict=False)
# Sanity check: the only parameters the checkpoint did not supply must be the
# classifier's weight and bias (compared as an ordered list).
assert log.missing_keys == ['densenet121.classifier.0.weight', 'densenet121.classifier.0.bias']

In [None]:
# Build the train/test loaders for the dataset named in the config.
if config.dataset_name == 'cifar10':
    train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == 'stl10':
    train_loader, test_loader = get_stl10_data_loaders(download=True)
elif config.dataset_name == 'chexnet':
    train_loader, test_loader = get_chestxray_data_loaders(config.root_folder)
else:
    # Fail here, with a clear message, instead of with a NameError on
    # `train_loader` once the training cell runs.
    raise ValueError(
        "Unsupported config.dataset_name {!r}; expected 'cifar10', 'stl10' or 'chexnet'".format(config.dataset_name))
print("Dataset:", config.dataset_name)

In [None]:
# Freeze everything except the classifier's linear layer: parameters with any
# other name stop receiving gradients, the two classifier tensors are left untouched.
for name, param in model.named_parameters():
    if name in ['densenet121.classifier.0.weight', 'densenet121.classifier.0.bias']:
        continue
    param.requires_grad = False

# Collect what is still trainable and verify it is exactly weight + bias.
parameters = [p for p in model.parameters() if p.requires_grad]
assert len(parameters) == 2  # fc.weight, fc.bias

In [None]:
# Adam over model.parameters() with learning rate 3e-4 and weight decay 8e-4.
# NOTE(review): the filtered `parameters` list built in the freezing cell is
# not what is passed here — confirm this is intentional before changing it,
# since the training cell later calls optimizer.load_state_dict(...).
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
# Binary cross-entropy on probabilities (DenseNet121's classifier ends in a Sigmoid).
criterion = torch.nn.BCELoss().to(device)

In [None]:
def accuracy(output, target):
    """Percentage of positive labels recovered by per-sample top-k prediction.

    For each sample, k is the number of positive entries in its target row;
    the k highest-scoring classes in ``output`` are taken as the prediction.
    The result is the share of all positive target entries in the batch that
    were predicted, scaled to 0-100.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 2-D multi-hot tensor with the same shape as ``output``.

    Returns:
        A Python float in [0, 100]. Returns 0.0 when the batch contains no
        positive labels at all (previously this was a 0/0 division that
        produced NaN and poisoned the epoch average).
    """
    with torch.no_grad():
        positives_per_sample = torch.sum(target, dim=1)
        total_positives = torch.sum(target)
        if total_positives.item() == 0:
            return 0.0

        # Build the multi-hot prediction on the same device as the scores
        # instead of relying on a module-level `device` variable.
        pred = torch.zeros(target.size(0), target.size(1), device=output.device)
        for i, scores in enumerate(output):
            k = int(positives_per_sample[i].item())
            if k > 0:
                pred[i, scores.topk(k).indices] = 1.

        # Keep only predictions at positions where the target is positive.
        hits = torch.masked_select(pred, target > 0)
        return (torch.sum(hits) / total_positives).item() * 100

In [None]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename``, hand the file to W&B, and log the path.

    Args:
        state: object to serialise with ``torch.save``.
        filename: destination path of the checkpoint file.
    """
    torch.save(obj=state, f=filename)
    # policy="now" is passed through to wandb.save unchanged.
    wandb.save(filename, policy="now")
    print("save", filename)

In [None]:
# Last epoch number to run (the loop below is inclusive of this value).
epochs = 7
wandb.watch(model)
model = model.to(device)
# Resume bookkeeping is taken from the checkpoint loaded earlier in the notebook.
# NOTE(review): that checkpoint comes from '/kaggle/input/chexnet-simclr-model/';
# if it is a pretraining checkpoint rather than one written by this notebook,
# its epoch counter, best_valid_loss and optimizer state may not correspond to
# this fine-tuning run (and start_epoch may already exceed `epochs`) — confirm.
start_epoch = checkpoint['epoch'] + 1
best_valid_loss = checkpoint['best_valid_loss']
optimizer.load_state_dict(checkpoint['optimizer'])

# Fine-tuning loop: one training pass and one validation pass per epoch,
# followed by checkpointing and metric logging.
for epoch in range(start_epoch, epochs + 1):
    # The CUDA queries raise on CPU-only hosts, so only issue them when a GPU
    # is present; `device` itself may legitimately be 'cpu'.
    if torch.cuda.is_available():
        print(torch.cuda.current_device(), torch.cuda.get_device_name(0), "epoch", epoch)
    else:
        print(device, "epoch", epoch)

    top1_train_accuracy = 0
    train_loss = 0
    valid_loss = 0

    # ---- training pass ----
    model.train()
    for counter, (x_batch, y_batch) in enumerate(tqdm(train_loader)):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        # For the chexnet arch these are sigmoid probabilities (the name
        # `logits` is historical), which is what BCELoss expects.
        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        train_loss += loss.item()
        top1 = accuracy(logits, y_batch)
        top1_train_accuracy += top1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # counter is the index of the last batch, so counter + 1 batches were seen.
    top1_train_accuracy /= (counter + 1)

    # ---- validation pass ----
    top1_accuracy = 0
    model.eval()
    with torch.no_grad():
        for counter, (x_batch, y_batch) in enumerate(tqdm(test_loader)):
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch)
            loss = criterion(logits, y_batch)
            valid_loss += loss.item()
            top1 = accuracy(logits, y_batch)
            top1_accuracy += top1
    top1_accuracy /= (counter + 1)

    # Turn the summed batch losses into per-batch means.
    train_loss = train_loss / len(train_loader)
    valid_loss = valid_loss / len(test_loader)

    # ---- checkpointing ----
    is_best = valid_loss < best_valid_loss
    if is_best:
        best_valid_loss = valid_loss

    # Built once, after best_valid_loss is updated, and written to both files.
    snapshot = {
        'epoch': epoch,
        'best_valid_loss': best_valid_loss,
        'arch': config.arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if is_best:
        save_checkpoint(snapshot, filename=os.path.join(config.result_dir, 'best_checkpoint.pth.tar'))
    save_checkpoint(snapshot, filename=os.path.join(config.result_dir, 'last_checkpoint.pth.tar'))

    # ---- logging ----
    wandb.log(
      {"epoch":epoch,
       "train loss":train_loss,
       "valid loss":valid_loss,
       "Top1 Train accuracy":top1_train_accuracy,
       "Top1 Test accuracy":top1_accuracy
       })
    print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy}\tTop1 Test accuracy: {top1_accuracy}\tTrain_loss: {train_loss}\tValid_loss: {valid_loss}")