In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import string
import glob
from tqdm import tqdm
import cv2
import pickle
from concurrent.futures import ProcessPoolExecutor

In [2]:
class Net(nn.Module):
    '''Small CNN: two conv + ReLU + max-pool stages followed by three fully connected layers.

    Input is ``(batch, in_channels, image_size, image_size)``; output is
    ``(batch, num_classes)`` raw logits (no softmax is applied; ``train`` pairs
    this with ``nn.CrossEntropyLoss``).
    '''

    def __init__(self, in_channels=6, num_classes=1108, image_size=512):
        '''
        Args:
            in_channels: Channels of the input image (default 6).
            num_classes: Number of output logits (default 1108).
            image_size: Side length of the square input (default 512); used to
                size ``fc1``.

        Raises:
            ValueError: if ``image_size`` is too small to survive both conv/pool stages.
        '''
        super(Net, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels, 3, 5)
        self.conv2 = nn.Conv2d(3, 2, 5)

        # Max-pooling layer
        self.pool = nn.MaxPool2d(2, 2)

        # Each unpadded 5x5 conv trims 4 pixels and each 2x2 pool halves (floor):
        # 512 -> 508 -> 254 -> 250 -> 125, i.e. 125 * 125 * 2 features by default.
        side = ((image_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(f'image_size={image_size} is too small for this network')
        self._flat_features = side * side * 2

        # Linear layers
        self.fc1 = nn.Linear(self._flat_features, 2000)
        self.fc2 = nn.Linear(2000, 1500)
        self.fc3 = nn.Linear(1500, num_classes)

    def forward(self, x):
        '''Return logits of shape ``(batch, num_classes)`` for a batch of images ``x``.'''
        # Convolutional layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten(1) keeps the batch axis intact; view(-1, N) would silently
        # regroup samples if the spatial size did not match fc1.
        x = torch.flatten(x, 1)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [3]:
# Peek at the training directory: one sub-folder per experiment plus the CSV files.
_train_parts = ('data', 'recursion', 'train')
PATH_TRAIN = os.path.join(os.getcwd(), *_train_parts)
os.listdir(PATH_TRAIN)

['HEPG2-01',
 'HEPG2-02',
 'HEPG2-03',
 'HEPG2-04',
 'HEPG2-05',
 'HEPG2-06',
 'HEPG2-07',
 'HUVEC-01',
 'HUVEC-02',
 'HUVEC-03',
 'HUVEC-04',
 'HUVEC-05',
 'HUVEC-06',
 'HUVEC-07',
 'HUVEC-08',
 'HUVEC-09',
 'HUVEC-10',
 'HUVEC-11',
 'HUVEC-12',
 'HUVEC-13',
 'HUVEC-14',
 'HUVEC-15',
 'HUVEC-16',
 'RPE-01',
 'RPE-02',
 'RPE-03',
 'RPE-04',
 'RPE-05',
 'RPE-06',
 'RPE-07',
 'train.csv',
 'train_controls.csv',
 'U2OS-01',
 'U2OS-02',
 'U2OS-03']

In [4]:
# Directory layout: <cwd>/data/recursion/{train,test}
DATA_DIR = os.path.join(os.getcwd(), 'data')
RECURSION_DIR = os.path.join(DATA_DIR, 'recursion')
RECURSION_TRAIN, RECURSION_TEST = (os.path.join(RECURSION_DIR, split) for split in ('train', 'test'))

CELL_TYPES = ['HEPG2', 'HUVEC', 'RPE', 'U2OS']
PLATES = [f'Plate{n}' for n in range(1, 5)]

# Well letters 'B'..'O' (14 letters) map to indices 0..13, and back.
LETTER_TO_IX = {letter: ix for ix, letter in enumerate(string.ascii_uppercase[1:15])}
IX_TO_LETTER = {ix: letter for letter, ix in LETTER_TO_IX.items()}

def parse_filename(s, full_path=False):
    ''' Parse a Kaggle image filename such as ``'B02_s1_w1.png'`` into zero-based indices.

    Args:
        s: The filename, or a path ending in it when ``full_path`` is True.
        full_path: If True, only the final path component of ``s`` is parsed.

    Returns:
        ``(row, col, site, channel)``: ``row`` from the number after the letter
        (``02`` -> 0), ``col`` from the leading letter via ``LETTER_TO_IX``
        (``'B'`` -> 0), ``site`` from ``s<k>`` and ``channel`` from ``w<k>``
        (both shifted to start at 0).

    Raises:
        ValueError: if the name is not of the form ``<well>_s<k>_w<k>[.ext]``.
        KeyError: if the leading letter is not in ``LETTER_TO_IX``.
    '''
    # NOTE(review): the letter is returned as "col" and the number as "row";
    # confirm this naming is intended. The order is kept because Plate's
    # (22, 14, ...) array layout indexes by it.
    if full_path:
        # basename rather than a fixed s[-13:] slice, so names of any length parse.
        s = os.path.basename(s)
    stem = os.path.splitext(s)[0]
    parts = stem.split('_')
    if (len(parts) != 3 or len(parts[0]) < 2
            or not parts[1].startswith('s') or not parts[2].startswith('w')):
        raise ValueError(f'Unrecognised image filename: {s!r}')
    well, site_tok, channel_tok = parts
    col = LETTER_TO_IX[well[0]]
    row = int(well[1:]) - 2
    site = int(site_tok[1:]) - 1
    channel = int(channel_tok[1:]) - 1
    return row, col, site, channel

def read_image(path):
    '''Read ``path`` as a single-channel image and return it as float32 scaled by 1/255.

    ``IMREAD_GRAYSCALE`` (without ``IMREAD_ANYDEPTH``) yields 8-bit pixels, so the
    result lies in [0, 1].

    Raises:
        OSError: if OpenCV cannot read ``path``. ``cv2.imread`` signals failure
            by returning ``None`` rather than raising.
    '''
    raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        raise OSError(f'cv2.imread could not read image: {path}')
    return raw.astype(np.float32) / 255.0

def read_parse_image(path):
    '''Read the image at ``path`` and parse its filename; used as the work item in ``Plate.load_images``.

    Returns:
        ``(img, info)``: ``img`` is the float32 array from ``read_image`` and
        ``info`` is the ``(row, col, site, channel)`` tuple from
        ``parse_filename(path, full_path=True)``.
    '''
    # Delegate to read_image so both loaders share one decode/scale path.
    img = read_image(path)
    info = parse_filename(path, True)
    return img, info

In [5]:
class Plate:
    '''In-memory image store for one plate of one experiment.

    ``images`` is indexed ``[row, col, site, channel, y, x]`` with the indices
    returned by ``parse_filename``; slots that were never loaded hold -1.
    ``labels`` is a ``(22, 14)`` int32 grid initialised to -1 (nothing in this
    class fills it in).
    '''

    def __init__(self, cell_type, plate_num, image_size=512):
        '''
        Args:
            cell_type: Cell-type string, shown in ``__repr__``.
            plate_num: Plate number, shown in ``__repr__``.
            image_size: Side length of each square image (default 512).
        '''
        # NOTE(review): at the default size this is 22*14*2*6*512*512 float32
        # values (~3.9 GB) per plate.
        # np.full allocates once; ``np.zeros(...) - 1`` created a zero array and
        # then a second full-size result array.
        self.images = np.full((22, 14, 2, 6, image_size, image_size), -1, dtype=np.float32)
        self.labels = np.full((22, 14), -1, dtype=np.int32)
        self.cell_type = cell_type
        self.plate_num = plate_num

    def load_images(self, files, max_workers=4):
        '''Decode ``files`` in parallel and write each into its slot of ``self.images``.

        Args:
            files: Iterable of paths to image files (filenames parsed with
                ``parse_filename(..., full_path=True)``).
            max_workers: Number of worker processes (default 4).
        '''
        files = list(files)
        if not files:
            # zip(*[]) would fail to unpack into two names; nothing to load.
            return
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            images, indices = zip(*executor.map(read_parse_image, files))
        images = np.array(images)
        indices = np.array(indices)
        # Fancy-index the four leading axes at once: one (y, x) image per file.
        self.images[indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3]] = images

    def get_image(self, s):
        '''Return the stored image at the indices ``parse_filename(s)`` gives for filename ``s``.'''
        ix = parse_filename(s)
        return self.images[ix]

    def __repr__(self):
        return f'<Plate: plate_num: {self.plate_num}, cell_type: {self.cell_type}>'

class Experiment:
    '''One experiment folder (``<split>/<cell_type>-<NN>``) and its loaded plates.'''

    def __init__(self, cell_type, exp_num, split):
        '''
        Args:
            cell_type: Cell-type prefix of the folder name, e.g. ``'HEPG2'``.
            exp_num: Experiment number, zero-padded to two digits in the folder name.
            split: Sub-directory of ``RECURSION_DIR`` holding the experiment
                (e.g. ``'train'`` or ``'test'``).
        '''
        self.cell_type = cell_type
        self.exp_num = exp_num
        self.split = split
        self.plates = []

    def load_plates(self):
        '''Load every plate listed in ``PLATES`` from disk into ``self.plates``.

        Calling this again replaces ``self.plates`` rather than appending duplicates.

        Raises:
            FileNotFoundError: if a plate directory does not exist.
        '''
        exp_dir = os.path.join(RECURSION_DIR, self.split, '{}-{:02d}'.format(self.cell_type, self.exp_num))
        # Release previously loaded plates first: each Plate holds a large array.
        self.plates = []
        for plate_num, plate_name in enumerate(PLATES, start=1):
            plate_dir = os.path.join(exp_dir, plate_name)
            # Check before allocating the (large) Plate so a bad path fails fast.
            if not os.path.isdir(plate_dir):
                raise FileNotFoundError(f'Plate directory not found: {plate_dir}')
            plate_files = sorted(glob.glob(os.path.join(plate_dir, '*.png')))
            plate = Plate(self.cell_type, plate_num)
            plate.load_images(plate_files)
            self.plates.append(plate)

    def __repr__(self):
        return f'<Experiment: cell_type: {self.cell_type}, exp_num: {self.exp_num}, split: {self.split} >'

In [6]:
class Dataset(torch.utils.data.Dataset):
    '''Map-style dataset yielding one 6-channel image and its sirna label per CSV row.'''

    def __init__(self, csv_file, root_dir, site=1):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            site (int): Which site (the ``s<k>`` part of the filename) to load. Defaults to 1.
        """
        self.data = pd.read_csv(csv_file)
        self.root = root_dir
        self.site = site

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        '''Return ``{'image': (6, 512, 512) float32 array, 'label': shape-(1,) array}`` for row ``idx``.'''
        # Columns are unpacked by position; assumes the CSV order is
        # id, experiment, plate, well, sirna -- TODO confirm against train.csv.
        _id, experiment, plate, well, sirna = self.data.iloc[idx]
        # float32 matches read_image's output and halves memory versus float64.
        image = np.zeros((6, 512, 512), dtype=np.float32)
        plate_dir = os.path.join(self.root, experiment, f'Plate{plate}')
        for c in range(6):
            img_name = f'{well}_s{self.site}_w{c+1}.png'
            image[c, :, :] = read_image(os.path.join(plate_dir, img_name))
        sirna_label = np.array([sirna])
        return {'image': image, 'label': sirna_label}

In [7]:
def load_data(batch_size=4, shuffle=True):
    '''Build a DataLoader over the training split described by ``train.csv``.

    Args:
        batch_size: Samples per batch (default 4).
        shuffle: Reshuffle every epoch (default True). Pass False for a
            deterministic pass in CSV order.

    Returns:
        torch.utils.data.DataLoader over ``Dataset`` yielding dicts with
        ``'image'`` and ``'label'``.
    '''
    csv_path = os.path.join(RECURSION_TRAIN, 'train.csv')
    trainset = Dataset(csv_path, RECURSION_TRAIN)
    # This loader feeds SGD training, which expects shuffled mini-batches; CSV
    # order may group related rows (TODO confirm train.csv ordering).
    return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=shuffle)
    

def train(net, trainloader, save_path, num_epoch=1, lr=0.001, momentum=0.9, device='cpu'):
    '''
    Function: train
    Train ``net`` with SGD on cross-entropy loss, then save its ``state_dict``.

    arguments:
        net - CNN model used for training
        trainloader - iterable of dicts with 'image' and 'label' tensors (e.g. a Torch DataLoader)
        save_path - file path to save the trained model's state_dict; parent dirs are created
        num_epoch - number of passes over trainloader (default 1)
        lr, momentum - SGD hyper-parameters (defaults 0.001 and 0.9)
        device - device to train on (default 'cpu')
    '''
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    net.to(device)
    net.train()

    for epoch in range(num_epoch):
        for i, data in enumerate(trainloader):
            inputs = data['image'].to(device, dtype=torch.float)
            # reshape(-1) instead of squeeze(): a batch of one must stay shape (1,);
            # squeeze() would give a 0-d target that CrossEntropyLoss rejects.
            labels = data['label'].to(device, dtype=torch.long).reshape(-1)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                print(f'Completed {i} iterations of training (epoch {epoch}, loss {loss.item():.4f})')

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(net.state_dict(), save_path)

In [None]:
# Get trainloader
trainloader = load_data()

# Create net
net = Net()

# Train the model
SAVE_PATH = os.path.join(os.getcwd(), 'models')
train(net, trainloader, SAVE_PATH, num_epoch=1)

Completed 0 iterations of training
Completed 10 iterations of training
Completed 20 iterations of training
Completed 30 iterations of training
Completed 40 iterations of training
Completed 50 iterations of training
Completed 60 iterations of training
Completed 70 iterations of training
Completed 80 iterations of training
Completed 90 iterations of training
Completed 100 iterations of training
Completed 110 iterations of training
Completed 120 iterations of training
