In [1]:
import torch
import torchvision.transforms as transforms
import numpy as np
import random
from torch.utils.data import Dataset, random_split, DataLoader
import pandas as pd
from PIL import Image
import os
import matplotlib.pyplot as plt
from matplotlib import cm

import torchvision.models as models

# Where the per-channel PNG images live, and the CSV that lists each image
# 'Id' together with its space-separated 'Target' label indices.
DATASET_PATH = 'data/train'
TRAIN_CSV = 'data/train.csv'

# Reproducibility: seed Python, NumPy and PyTorch, and ask cuDNN for
# deterministic kernels (benchmark mode is turned off as well).
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Human-readable name of every class; the position in the list is the
# integer class index used in the 'Target' column.
text_labels = dict(enumerate([
    "Nucleoplasm",
    "Nuclear membrane",
    "Nucleoli",
    "Nucleoli fibrillar center",
    "Nuclear speckles",
    "Nuclear bodies",
    "Endoplasmic reticulum",
    "Golgi apparatus",
    "Peroxisomes",
    "Endosomes",
    "Lysosomes",
    "Intermediate filaments",
    "Actin filaments",
    "Focal adhesion sites",
    "Microtubules",
    "Microtubule ends",
    "Cytokinetic bridge",
    "Mitotic spindle",
    "Microtubule organizing center",
    "Centrosome",
    "Lipid droplets",
    "Plasma membrane",
    "Cell junctions",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Cytoplasmic bodies",
    "Rods & rings",
]))

# Number of distinct classes (length of the multi-hot target vector).
NUM_LABELS = len(text_labels)



In [2]:
# Convert labels to binary arrays
def encode_label(label):
    """Turn a space-separated string of class indices into a multi-hot tensor.

    Args:
        label: value whose ``str()`` form is class indices separated by single
            spaces (e.g. ``"0 5 12"``); a bare integer works too.

    Returns:
        A float tensor of length ``NUM_LABELS`` holding 1.0 at every listed
        index and 0.0 elsewhere.
    """
    class_indices = [int(token) for token in str(label).split(' ')]
    multi_hot = torch.zeros(NUM_LABELS)
    multi_hot[class_indices] = 1.
    return multi_hot

def decode_target(target, text_labels=False, threshold=0.5):
    """Convert a vector of per-class scores back into a label string.

    Args:
        target: iterable of per-class scores (list or 1-D tensor).
        text_labels: ``False`` (default) emits bare class indices. ``True``
            emits ``"<name>(<index>)"`` using the module-level ``text_labels``
            table. A non-empty mapping ``{index: name}`` may be passed
            instead to supply the names directly.
        threshold: a class is reported when its score is ``>= threshold``.

    Returns:
        The selected classes joined by single spaces ('' when none qualify).
    """
    if isinstance(text_labels, dict) and text_labels:
        names = text_labels
    elif text_labels:
        # The parameter shadows the module-level name table, so it has to be
        # fetched from the module namespace explicitly. (The original read an
        # undefined name `labels` here and raised NameError.)
        names = globals()["text_labels"]
    else:
        names = None

    result = []
    for i, x in enumerate(target):
        if x >= threshold:
            if names is not None:
                result.append(names[i] + "(" + str(i) + ")")
            else:
                result.append(str(i))
    return ' '.join(result)

In [3]:
# Class to load the dataset
class HumanProteinDataset(Dataset):
    """Dataset of four-channel protein images described by a CSV file.

    The CSV must provide an 'Id' column (image identifier) and a 'Target'
    column (space-separated class indices). Each image is stored as four
    PNG files named ``<root_dir>/<Id>_<color>.png`` for the colors red,
    green, blue and yellow.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Read the CSV index and remember where the images are.

        Args:
            csv_file: path (or file-like object) handed to ``pd.read_csv``.
            root_dir: directory containing the per-channel PNG files.
            transform: optional callable applied to the merged PIL image.
        """
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of rows in the CSV index."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, multi_hot_target)`` for the row labelled ``idx``."""
        record = self.data.loc[idx]
        sample_id = record['Id']
        sample_label = record['Target']
        image = self.open_rgby(sample_id)
        # The transform (when given) is what resizes / tensorizes the image.
        if self.transform:
            image = self.transform(image)
        return image, encode_label(sample_label)

    def open_rgby(self, id):
        """Load the four single-channel PNGs of ``id`` as one 'RGBA' image.

        The channels are stacked in the order red, green, blue, yellow, so
        the yellow filter occupies the alpha slot of the merged image.
        """
        channels = []
        for color in ('red', 'green', 'blue', 'yellow'):
            channels.append(Image.open(f"{self.root_dir}/{id}_{color}.png"))
        return Image.merge('RGBA', channels)

In [4]:
def load_data(batch_size):
    """Build the training and test data loaders.

    The full dataset is read from ``TRAIN_CSV`` / ``DATASET_PATH``, every
    image is resized with ``Resize(256)`` and converted to a tensor, and
    10 percent of the samples are split off at random as a test set.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        ``(train_dataloader, test_dataloader)``; both loaders shuffle and
        load in the main process (``num_workers=0``).
    """
    preprocessing = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
    full_dataset = HumanProteinDataset(
        csv_file=TRAIN_CSV,
        root_dir=DATASET_PATH,
        transform=preprocessing,
    )

    # Hold out 10 percent of the samples for testing.
    held_out_fraction = 0.1
    n_test = int(held_out_fraction * len(full_dataset))
    n_train = len(full_dataset) - n_test
    train_subset, test_subset = random_split(full_dataset, [n_train, n_test])

    # Both loaders share exactly the same settings.
    loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=0)
    train_dataloader = torch.utils.data.DataLoader(train_subset, **loader_options)
    test_dataloader = torch.utils.data.DataLoader(test_subset, **loader_options)
    return train_dataloader, test_dataloader

In [5]:
# Train NN
def test(test_dataloader, model):
    '''
    This function will test the model performance using testing data.
    '''
    with torch.no_grad():
        accu_number = 0.0
        for x, y in test_dataloader:
            x, y = x.to(device), y.to(device)
            predicted_class = model(x)
            accu_number += torch.sum(predicted_class == y)
        print('testing accuracy: %.4f' % (accu_number/len(test_dataloader.dataset)))
        return (accu_number/len(test_dataloader.dataset)).item()
        
def train(dataloader, model, loss_fn, optimizer, threshold=0.5):
    '''
    Run one epoch of training and report the training accuracy.

    Accuracy is exact-match: a class counts as predicted when its score is
    >= `threshold`, and a sample is correct only when every class decision
    matches its multi-hot target. The original compared raw float scores to
    the 0/1 targets with `==`, which practically never matches.

    Args:
        dataloader: DataLoader yielding (inputs, multi-hot targets).
        model: torch.nn.Module mapping inputs to per-class scores.
        loss_fn: callable `loss_fn(pred, y)` returning a scalar loss tensor.
        optimizer: optimizer over the model parameters.
        threshold: decision threshold applied to the scores (default 0.5).

    Returns:
        Fraction of samples predicted exactly right during the epoch, as a
        Python float (0.0 when the dataset is empty).
    '''
    # Make sure layers such as dropout / batch-norm are in training mode.
    model.train()
    n_exact = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Bookkeeping only; it must not take part in the autograd graph.
        with torch.no_grad():
            matches = ((pred >= threshold) == (y >= 0.5)).reshape(y.shape[0], -1).all(dim=1)
        n_exact += matches.sum().item()

    n_samples = len(dataloader.dataset)
    accuracy = n_exact / n_samples if n_samples else 0.0
    print('training accuracy: %.4f' % accuracy)
    return accuracy
        
class Network(torch.nn.Module):
    """Small CNN producing 28 independent per-class probabilities.

    Four identical stages (3x3 convolution with stride 2, ReLU, 2x2 max-pool)
    widen the channels 4 -> 32 -> 64 -> 128 -> 256; a three-layer fully
    connected head followed by a sigmoid yields one score in (0, 1) per class.
    The head expects exactly 256 flattened features, which is what a
    256x256 input produces (each stage shrinks each spatial side by 4).
    """

    def __init__(self):
        super().__init__()
        # The input has 4 channels because the four filter images are merged
        # into a single image.
        channel_plan = [(4, 32), (32, 64), (64, 128), (128, 256)]
        feature_stages = []
        for in_channels, out_channels in channel_plan:
            feature_stages.append(
                torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            )
            feature_stages.append(torch.nn.ReLU())
            feature_stages.append(torch.nn.MaxPool2d(2, 2))

        # NOTE(review): there is no activation between the first two Linear
        # layers, so together they act as a single affine map - confirm
        # whether a ReLU was intended there.
        classifier_head = [
            torch.nn.Flatten(),
            torch.nn.Linear(256, 64),
            torch.nn.Linear(64, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 28),
            # Sigmoid scores every class independently, so several labels can
            # be active at once (unlike a softmax, which picks one).
            torch.nn.Sigmoid(),
        ]
        self.layers = torch.nn.Sequential(*feature_stages, *classifier_head)

    def forward(self, x):
        """Return the per-class sigmoid scores for a batch of images."""
        return self.layers(x)
    
# class Resnet34(torch.nn.Module):
#     # Preloaded Resnet34 network to speed up training
#     def __init__(self):
#         super().__init__()
#         self.network = models.resnet34(pretrained=True)
#         # weight for RGB is from Resnet34, weight for Y is set to mean(weight of RGB)
#         weight = self.network.conv1.weight.clone()
#         self.network.conv1 = torch.nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         with torch.no_grad():
#             self.network.conv1.weight[:,:3] = weight
#             self.network.conv1.weight[:, 3] = torch.mean(weight, dim=1)
#         # update out_features to NUM_LABELS
#         in_features = self.network.fc.in_features
#         self.network.fc = torch.nn.Linear(in_features, NUM_LABELS)
        
            
#     def forward(self, xb):
#         return torch.sigmoid(self.network(xb))
    


    
    

In [6]:
# Plot the training and test error

def plot(train_loss, test_loss, epochs):
    """Plot the per-epoch training and test curves on one figure.

    Args:
        train_loss: sequence with one training value per epoch.
        test_loss: sequence with one test value per epoch.
        epochs: number of epochs, i.e. the length of both sequences.
    """
    # One x position per epoch. The original used range(0, epochs - 1), which
    # is one element short of the series and makes plt.plot raise a
    # shape-mismatch error.
    epoch_axis = range(epochs)
    plt.plot(epoch_axis, train_loss, 'g', label='Training loss')
    plt.plot(epoch_axis, test_loss, 'b', label='Test loss')
    plt.title('Training and Test loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [7]:
def main(hyper_param):
    """Train and evaluate the network, then plot the per-epoch curves.

    Args:
        hyper_param: dict with the keys
            'batch_size'  - samples per batch,
            'n_epochs'    - number of training epochs,
            'optimizer'   - name of a class in ``torch.optim`` (e.g. 'SGD'),
            'optim_param' - keyword arguments for that optimizer.
    """
    #model = Resnet34().to(device)
    model = Network().to(device)
    train_dataloader, test_dataloader = load_data(hyper_param['batch_size'])

    # The network ends in a sigmoid and the targets are multi-hot vectors, so
    # the matching loss is binary cross-entropy on probabilities.
    # CrossEntropyLoss would apply a softmax across classes on top of the
    # sigmoid outputs and models a single-label problem.
    loss_fn = torch.nn.BCELoss()

    optimizer_cls = getattr(torch.optim, hyper_param['optimizer'])
    optimizer = optimizer_cls(model.parameters(), **hyper_param['optim_param'])

    # train() and test() both return an accuracy, one value per epoch.
    train_history = []
    test_history = []
    for epoch in range(hyper_param['n_epochs']):
        print(f"Epoch {epoch}", end=' ')
        train_history.append(train(train_dataloader, model, loss_fn, optimizer))
        test_history.append(test(test_dataloader, model))
    plot(train_history, test_history, hyper_param['n_epochs'])

In [8]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Hyper-parameters
hyper_param = dict(
    batch_size=100,
    n_epochs=5,
    optimizer='SGD',
    # Keyword arguments for the chosen optimizer; adjust them whenever
    # 'optimizer' is changed.
    optim_param=dict(lr=0.06),
)

# Build, train and evaluate the network with the settings above.
main(hyper_param)

Epoch 0 

KeyboardInterrupt: 