In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import os
import torch
import torch.nn as nn
import torch.optim as opt
torch.set_printoptions(linewidth=120)
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
print(torch.__version__)

1.8.1+cu111


In [2]:
training = os.listdir('../train/')
study_label = pd.read_csv('../archive/train_study_level.csv')
image_label = pd.read_csv('../archive/train_image_level.csv')
paths = []
labels = []
dct = pd.read_csv('train.csv', index_col=0)
dct['image'] = dct.image.apply(lambda x: x[:-4])
dct = dct.set_index('study').to_dict()['image']

In [3]:
for index, row in study_label.iterrows():
    name = dct[row['id'].replace('_study', '')] + '.png'
    if name in training:
        paths.append('../train/' + name)
        if row['Negative for Pneumonia'] == 1:
            labels.append(0)
        elif row['Typical Appearance'] == 1:
            labels.append(1)
        elif row['Indeterminate Appearance'] == 1:
            labels.append(2)
        elif row['Atypical Appearance'] == 1:
            labels.append(3)
    else:
        print(name)

In [4]:
pd.Series(labels).value_counts()

1    2855
0    1676
2    1049
3     474
dtype: int64

In [5]:
class Net(nn.Module):
    def __init__(self, out_size, model):
        super(Net, self).__init__()
        if model == 'dense':
            self.model = torchvision.models.densenet121(pretrained=True, **{'drop_rate' : 0.3})
            num_ftrs = self.model.classifier.in_features
            self.model.classifier = nn.Sequential(
                nn.Linear(num_ftrs, out_size),
                nn.Sigmoid()
            )
            
        elif model == 'res':
            self.model = torchvision.models.wide_resnet101_2(pretrained=True)
            num_ftrs = self.model.fc.in_features
            self.model.fc = nn.Sequential(
                nn.Linear(num_ftrs, out_size),
                nn.Sigmoid()
            )
        elif model == 'inception':
            self.model = torchvision.models.inception_v3(pretrained=True, **{"aux_logits": False})
            num_ftrs = self.model.fc.in_features
            self.model.fc = nn.Sequential(
                nn.Linear(num_ftrs, out_size),
                nn.Sigmoid()
            )
    def forward(self, x):
        x = self.model(x)
        return x

In [6]:
model = Net(4, 'inception')
print(model)

Net(
  (model): Inception3(
    (Conv2d_1a_3x3): BasicConv2d(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2a_3x3): BasicConv2d(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_2b_3x3): BasicConv2d(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (Conv2d_3b_1x1): BasicConv2d(
      (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (Conv2d_4a_3x3): BasicConv2d(


In [7]:
class classification(nn.Module):
    def __init__(self, paths, labels, size=(512,512), aug=False):
        self.paths = paths
        self.labels = labels
        self.aug = aug
        self.example = []
        self.size =size
    
    def __getitem__(self, idx):
            path = self.paths[idx]
            img = cv2.imread(path)
            R, G, B = cv2.split(img)
            output1_R = cv2.equalizeHist(R)
            output1_G = cv2.equalizeHist(G)
            output1_B = cv2.equalizeHist(B)

            img = cv2.merge((output1_R, output1_G, output1_B))
            img = (img - np.mean(img))/np.std(img)
            img = cv2.resize(img, self.size)
            x = torch.from_numpy(np.array(img)).view((3, self.size[0], self.size[1]))
            x = x.float()
            y = self.labels[idx]
            y = torch.tensor(y)
            return x, y
        
    def __len__(self):
        return len(self.paths)
    
    def get(self):
        return self.example

In [8]:
dataset = classification(paths[:500], labels[:500], size=(512, 512))
print(len(dataset))
train_set, val_set = torch.utils.data.random_split(dataset, [400, 100])

500


In [9]:
class Trainer():
    def __init__(self,model,train_set,test_set,opts):
        self.model = model  # neural net
        # device agnostic code snippet
        self.device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        print(self.device)
        self.model.to(self.device)
        
        self.epochs = opts['epochs']
        self.optimizer = torch.optim.Adam(model.parameters(), opts['lr'], weight_decay=1e-5, amsgrad=True) # optimizer method for gradient descent
        self.criterion = torch.nn.CrossEntropyLoss()                      # loss function
        self.train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                        batch_size=opts['batch_size'],
                                                        shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                                       batch_size=opts['batch_size'],
                                                       shuffle=False)
        self.tb = SummaryWriter(log_dir='./resruns')
        self.best_loss = 1e10
        
    def train(self):
        for epoch in range(self.epochs):
            self.model.train() #put model in training mode
            self.tr_loss = []
            for i, (data,labels) in tqdm(enumerate(self.train_loader),
                                                   total = len(self.train_loader)):
                data, labels = data.to(self.device),labels.to(self.device)
                self.optimizer.zero_grad()  
                outputs = self.model(data)   
                loss = self.criterion(outputs, labels) 
                loss.backward()                        
                self.optimizer.step()                  
                self.tr_loss.append(loss.item())     
            self.tb.add_scalar("Train Loss", np.mean(self.tr_loss), epoch)
            self.test(epoch) # run through the validation set
        self.tb.close()
            
    def test(self,epoch):
            
            self.model.eval()    # puts model in eval mode - not necessary for this demo but good to know
            self.test_loss = []
            self.test_accuracy = []
            
            for i, (data, labels) in enumerate(self.test_loader):
                
                data, labels = data.to(self.device),labels.to(self.device)
                
                with torch.no_grad():
                    outputs = self.model(data)
                
                _, predicted = torch.max(outputs.data, 1)
                loss = self.criterion(outputs, labels)
                self.test_loss.append(loss.item())
                
                self.test_accuracy.append((predicted == labels).sum().item() / predicted.size(0))
            
            print('epoch: {}, train loss: {}, test loss: {}, test accuracy: {}'.format( 
                  epoch+1, np.mean(self.tr_loss), np.mean(self.test_loss), np.mean(self.test_accuracy)))
            self.tb.add_scalar("Val Acc", np.mean(self.test_accuracy), epoch)
            self.tb.add_scalar("Val Loss", np.mean(self.test_loss), epoch)
            if np.mean(self.test_loss) < self.best_loss:
                self.best_loss = np.mean(self.test_loss)
                #torch.save(self.model.state_dict(), './model_weights/resbest.pt')

In [10]:
opts = {
    'lr': 1e-4,
    'epochs': 60,
    'batch_size': 32
}
train = Trainer(model, train_set, train_set, opts)
train.train()

cuda:1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 1, train loss: 1.3523186903733473, test loss: 1.297748171366178, test accuracy: 0.46634615384615385


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 2, train loss: 1.2331673823870146, test loss: 1.1957289714079637, test accuracy: 0.5793269230769231


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 3, train loss: 1.1114204021600576, test loss: 1.0339606496003957, test accuracy: 0.7283653846153846


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 4, train loss: 0.9872334599494934, test loss: 0.908029134456928, test accuracy: 0.9278846153846154


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 5, train loss: 0.9129250599787786, test loss: 0.8613543189488925, test accuracy: 0.9206730769230769


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 6, train loss: 0.860667714705834, test loss: 0.8333195814719567, test accuracy: 0.9158653846153846


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 7, train loss: 0.8415205616217393, test loss: 0.8125237639133747, test accuracy: 0.9158653846153846


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 8, train loss: 0.8231209929172809, test loss: 0.7983141908278832, test accuracy: 0.9879807692307693


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 9, train loss: 0.8026299063975995, test loss: 0.7824431336843051, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 10, train loss: 0.7869992989760178, test loss: 0.770764525120075, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 11, train loss: 0.7727933480189397, test loss: 0.7586588263511658, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 12, train loss: 0.7632628129078791, test loss: 0.7516614565482507, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 13, train loss: 0.7585739355820876, test loss: 0.7493679981965286, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 14, train loss: 0.7554699732707097, test loss: 0.7481390191958501, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 15, train loss: 0.7577875118989211, test loss: 0.751068679186014, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 16, train loss: 0.7599801604564373, test loss: 0.7504296348645136, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 17, train loss: 0.7576320996651282, test loss: 0.7479320626992446, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))


epoch: 18, train loss: 0.7565754331075228, test loss: 0.7471793477351849, test accuracy: 1.0


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=13.0), HTML(value='')))




KeyboardInterrupt: 