In [None]:
import os

import numpy as np
import pandas as pd
import skimage
import torch
import torch.nn.functional as F
from PIL import Image
from skimage.io import imread, imsave
from sklearn import metrics
from sklearn.metrics import roc_auc_score
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

In [None]:
# Channel-wise normalisation; the same mean / std value is applied to each of
# the three channels.
normalize = transforms.Normalize(
    mean=[0.45271412] * 3,
    std=[0.33165374] * 3,
)

# Training pipeline: resize, then random crop / horizontal flip / brightness
# and contrast jitter as augmentation, before tensor conversion and
# normalisation.
train_transformer = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: fixed 224x224 resize with no random augmentation.
val_transformer = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

In [None]:
def read_txt(txt_path):
    """Read a text file and return its non-empty lines, whitespace-stripped.

    Args:
        txt_path (str): Path to a text file holding one entry per line.

    Returns:
        list[str]: The stripped lines in file order. Blank or
        whitespace-only lines (e.g. a trailing empty line at the end of the
        file) are skipped so they never turn into an empty entry.
    """
    with open(txt_path) as f:
        stripped = (line.strip() for line in f)
        # Drop empty entries: an empty file name would later be joined into a
        # path that points at a directory rather than an image.
        return [entry for entry in stripped if entry]

class CovidCTDataset(Dataset):
    """Two-class CT image dataset built from per-class file-name lists.

    Each sample is a dict ``{'img': image, 'label': int}`` where the label is
    the index of the class in ``self.classes`` (0 = 'CT_COVID',
    1 = 'CT_NonCOVID').
    """

    def __init__(self, root_dir, txt_COVID, txt_NonCOVID, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            txt_COVID (string): Path to the txt file listing the image file
                names (one per line) found under ``root_dir/CT_COVID``.
            txt_NonCOVID (string): Path to the txt file listing the image file
                names (one per line) found under ``root_dir/CT_NonCOVID``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        File structure:
        - root_dir
            - CT_COVID
                - img1.png
                - img2.png
                - ......
            - CT_NonCOVID
                - img1.png
                - img2.png
                - ......
        """
        self.root_dir = root_dir
        self.txt_path = [txt_COVID, txt_NonCOVID]
        self.classes = ['CT_COVID', 'CT_NonCOVID']
        self.num_cls = len(self.classes)

        # img_list holds [image_path, class_index] pairs, all CT_COVID entries
        # first, followed by all CT_NonCOVID entries.
        self.img_list = []
        for label, (cls_name, list_file) in enumerate(zip(self.classes, self.txt_path)):
            self.img_list += [
                [os.path.join(self.root_dir, cls_name, item), label]
                for item in read_txt(list_file)
            ]
        self.transform = transform

    def __len__(self):
        """Return the total number of listed images across both classes."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it with its label.

        Args:
            idx (int or torch.Tensor): Index into ``self.img_list``.

        Returns:
            dict: ``{'img': image, 'label': int}``; ``image`` is the RGB PIL
            image, or the output of ``self.transform`` when one was given.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path, label = self.img_list[idx]
        # Force three channels regardless of how the file is stored on disk.
        image = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return {'img': image, 'label': int(label)}

In [None]:
batchsize = 16

# NOTE: the root_dir values below are placeholders (marked "#Update") and must
# be replaced with the real dataset location before running.
trainset = CovidCTDataset(root_dir='/......../COVID-Downstream', #Update
                            txt_COVID='Data-split/COVID/trainCT_COVID.txt',
                            txt_NonCOVID='Data-split/NonCOVID/trainCT_NonCOVID.txt',
                            transform= train_transformer)
valset = CovidCTDataset(root_dir='/............/COVID-Downstream', #Update
                            txt_COVID='Data-split/COVID/valCT_COVID.txt',
                            txt_NonCOVID='Data-split/NonCOVID/valCT_NonCOVID.txt',
                            transform= val_transformer)
testset = CovidCTDataset(root_dir='/.............../Barlow Twins/COVID-Downstream', #Update
                            txt_COVID='Data-split/COVID/testCT_COVID.txt',
                            txt_NonCOVID='Data-split/NonCOVID/testCT_NonCOVID.txt',
                            transform= val_transformer)
print(len(trainset))
print(len(valset))
print(len(testset))

# Training batches are shuffled; the final partial batch is dropped so every
# optimisation step sees a full batch.
train_loader = DataLoader(trainset, batch_size=batchsize, num_workers = 4, drop_last=True, shuffle=True)
# Keep the final partial batch so validation metrics cover every sample in
# valset (dropping it would silently discard up to batchsize - 1 samples).
val_loader = DataLoader(valset, batch_size=batchsize, num_workers = 4, drop_last=False, shuffle=False)
# NOTE(review): drop_last=True means the final partial test batch is skipped
# by anything iterating this loader directly.
test_loader = DataLoader(testset, batch_size=batchsize, num_workers = 4, drop_last=True, shuffle=False)

In [None]:
# Sanity check: pull a single batch from the training loader and display
# channel 1 of its first image (after the training transforms were applied).
a = next(iter(train_loader))

img = a['img']
skimage.io.imshow(img[0, 1].numpy())

In [None]:
#training process is defined here
def _evaluate(model, loader):
    """Run ``model`` over ``loader`` without gradients.

    Uses the module-level ``device`` and ``loss_fn``. The model is switched to
    eval mode and left in eval mode; the caller restores training mode.

    Args:
        model (nn.Module): Model to evaluate.
        loader (iterable): Yields dict batches with 'img' and 'label' tensors.

    Returns:
        tuple[float, float]: ``(mean loss per batch, accuracy)`` where accuracy
        is correct predictions divided by the number of samples actually seen.
        Both are 0.0 when ``loader`` yields nothing.
    """
    loss_total = 0.
    batches_seen = 0
    correct = 0
    samples_seen = 0

    model.eval() # set model to eval phase
    with torch.no_grad():
        for batch in loader:
            images = batch['img'].to(device)
            labels = batch['label'].to(device)

            outputs = model(images)
            loss_total += loss_fn(outputs, labels.long()).item()
            batches_seen += 1

            _, preds = torch.max(outputs, 1)
            # Tensor reduction + .item() keeps the counter a plain Python int.
            correct += (preds == labels).sum().item()
            samples_seen += labels.size(0)

    return loss_total / max(batches_seen, 1), correct / max(samples_seen, 1)


def train(epochs, model):
    """Train ``model`` for up to ``epochs`` epochs with periodic validation.

    Relies on module-level names defined elsewhere in the notebook:
    ``train_loader``, ``val_loader``, ``device``, ``optimizer``, ``loss_fn``
    and ``batchsize``.

    Every ``batchsize`` training steps (including step 0 of each epoch) the
    model is evaluated on ``val_loader``; training stops early as soon as the
    validation accuracy reaches 0.98.

    Args:
        epochs (int): Maximum number of passes over ``train_loader``.
        model (nn.Module): Model to optimise; expected to already be on
            ``device``.

    Returns:
        nn.Module: The same ``model`` object after training.
    """
    print('Starting training..')

    for e in range(0, epochs):
        print('='*16)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('='*16)

        train_loss = 0.
        train_batches = 0

        model.train() # set model to training phase

        for train_step, data in enumerate(train_loader):
            images = data['img'].to(device)
            labels = data['label'].to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(outputs, labels.long())
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_batches += 1

            # NOTE(review): the evaluation interval is tied to the batch size;
            # kept as in the original schedule.
            if train_step % batchsize == 0:
                print('Evaluating at step', train_step)
                # The loss is reported exactly as returned by loss_fn, the same
                # quantity that is averaged into the per-epoch "Training Loss".
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}'.format(e, train_step,
                       len(train_loader),100.0 * train_step / len(train_loader), loss.item()))

                # Fresh accumulators on every call, so one evaluation never
                # leaks into the next one within the same epoch.
                val_loss, accuracy = _evaluate(model, val_loader)

                print(f'Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}')

                #show_preds()

                model.train()

                if accuracy >= 0.98:
                    print('Performance condition satisfied, stopping..')
                    return model

        # max(..., 1) avoids a failure when train_loader yields no batches.
        train_loss /= max(train_batches, 1)

        print(f'Training Loss: {train_loss:.4f}')
    print('Training complete..')

    return model

In [None]:
class ResNetModelSSL(nn.Module):
    """ResNet-50 backbone initialised from a checkpoint, with a 2-class head."""

    def __init__(self, path):
        """
        Build a ResNet-50, load backbone weights from ``path`` and attach a
        freshly initialised two-output linear classifier.

        :param path: Checkpoint file readable by ``torch.load``. It must hold a
            ``"state_dict"`` mapping whose backbone entries are prefixed with
            ``"model.network."``; entries without that prefix are ignored.
        :raises RuntimeError: If no key carries the expected prefix, or if the
            extracted weights do not match a ResNet-50 without its ``fc`` layer.
        """
        super().__init__()

        prefix = "model.network."

        # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
        # machine without CUDA; the caller moves the model to its device later.
        # NOTE(security): torch.load unpickles the file - only load trusted
        # checkpoints.
        pretrained_dict = torch.load(path, map_location="cpu")["state_dict"]

        # Strip the prefix only from the start of each key (str.replace would
        # also remove any later occurrence inside the key).
        state_dict = {
            k[len(prefix):]: v
            for k, v in pretrained_dict.items()
            if k.startswith(prefix)
        }
        if not state_dict:
            raise RuntimeError(
                "No keys starting with '{}' found in checkpoint '{}'".format(prefix, path)
            )

        self.resnet = models.resnet50()
        # Remove the stock classifier so the strict load below only expects
        # backbone parameters.
        del self.resnet.fc

        self.resnet.load_state_dict(state_dict)

        # New classification head; 2048 is the input width of the layer that
        # was removed above.
        self.resnet.fc =  nn.Sequential(
            nn.Linear(2048, 2),
            )


    def forward(self, x):
        """Return the two-class logits produced by the ResNet for batch ``x``."""
        logits = self.resnet(x)

        return logits

In [None]:

# Per-run results; one entry is appended per trained model.
model_no = []
acc = []
F1 = []
AUC = []

epochs = 300

# Always define the device: fall back to the CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Evaluate on every test sample: rebuild the loader from test_loader's own
# dataset and settings with drop_last=False, whatever test_loader itself uses.
eval_loader = DataLoader(test_loader.dataset,
                         batch_size=test_loader.batch_size,
                         num_workers=test_loader.num_workers,
                         drop_last=False,
                         shuffle=False)

for i in range(1):

    print('='*16)
    print('Training the Model number {}'.format(i+1))

    model = ResNetModelSSL('last.ckpt')

    model = model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    loss_fn = loss_fn.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    model = train(epochs, model)

    # Test-time inference is done on the CPU, matching the batches produced by
    # the loader.
    model = model.to('cpu')
    model.eval() # set model to eval phase
    y_true=[]
    y_pred=[]
    y_score=[]
    # No gradients are needed for evaluation; this avoids building autograd
    # graphs for every test batch.
    with torch.no_grad():
        for batch in eval_loader:

            outputs = model(batch['img'])

            _, y = torch.max(outputs, 1)

            y_true.append(batch['label'])
            y_pred.append(y)
            # Probability of class index 1, used as the ranking score for AUC.
            y_score.append(torch.softmax(outputs, dim=1)[:, 1])

    # cat (not stack) so a smaller final batch is handled correctly.
    Y_true = torch.cat(y_true).numpy()
    Y_pred = torch.cat(y_pred).numpy()
    Y_score = torch.cat(y_score).numpy()

    model_no.append(i)

    # NOTE(review): with sklearn's default pos_label=1 the positive class for
    # F1 is label 1 - confirm this is the intended positive class.
    F1.append(metrics.f1_score(Y_true,Y_pred))
    # ROC AUC is computed from continuous scores rather than hard predictions.
    AUC.append(roc_auc_score(Y_true, Y_score))
    acc.append(np.mean(Y_true==Y_pred))


data = {'model': model_no,
         'F1': F1,
         'AUC': AUC,
         'acc': acc}   

df = pd.DataFrame(data)
df.to_csv('results.csv')