## Installs

In [1]:
!pip3 install torch
!pip3 install torchvision
!pip3 install tqdm



## Imports

In [0]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision import transforms, utils, datasets
import torchvision.models as models
from tqdm import tqdm
from torch.nn.parameter import Parameter
import pdb
import torchvision
import os
import sys
import gzip
import tarfile
import gc
from PIL import Image
import pandas as pd
from skimage import io, transform
import matplotlib.pyplot as plt
import re

from IPython.core.ultratb import AutoFormattedTB

assert torch.cuda.is_available()

#Get Data

In [3]:
# Load data
import zipfile
root = "."
datasets.utils.download_url('https://jspen14-data.s3.amazonaws.com/sample.zip', root, 'sample.zip', None)

with zipfile.ZipFile("sample.zip","r") as zip_ref:
  zip_ref.extractall("data/")

Downloading https://jspen14-data.s3.amazonaws.com/sample.zip to ./sample.zip


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

# Change Directory Structure

In [0]:
mv data/jsDataSample/* .

# Dataloader

In [0]:
class RecursionDataset(Dataset):
    """Recursion Dataset for Big Data Capstone."""

    def __init__(self, csv_file1, root_dir, csv_file2=None, transform=None):
        """
        Args:
            csv_file1 (string): Path to the csv file with most annotations.
            root_dir (string): Directory with all the batch folders containing images.
            csv_file2 (string): Path to the csv file with control annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.transform = transform

        self.csv = pd.read_csv(csv_file1)
        if csv_file2 != None:
            csv2 = pd.read_csv(csv_file2).loc[:,'id_code':'sirna']
            self.csv = pd.concat([self.csv, csv2])\
                         .reset_index(drop=True)
        self.csv['plate'] = 'Plate'+self.csv['plate'].astype(str) # Mimic folder naming for loading pics later
        
        # Create variable for both sites 1 and 2 of each well
        self.csv['site'] = 's1'
        csv_copy = self.csv.copy()
        csv_copy['site'] = 's2'
        self.csv = pd.concat([self.csv, csv_copy])\
                     .sort_values(['id_code', 'site'])\
                     .reset_index(drop=True)

        # Missing pictures that must be removed from csv file
        missingPics = [['HUVEC-06_1_B18', 's2'], ['RPE-04_3_E04', 's1']]
        for pic in missingPics:
          self.csv = self.csv[~((self.csv['id_code'] == pic[0]) & (self.csv['site'] == pic[1]))]

        #NOTE: FILTERING TO ONLY THE PICS OF TRAINING BATCH HEPG2-01
        self.csv = self.csv[self.csv.experiment=='HEPG2-01']
        #NOTE: RENAMING SIRNA CLASSES TO RANGE FROM [0,C-1]
        self.size = 64# NOTE
        self.csv = self.csv[:self.size]
        self.seen = {}
        new_class = 0
        for i in range(self.csv.shape[0]):
          sirna = self.csv['sirna'][i]
          if sirna not in self.seen:
            self.seen[sirna] = str(new_class)
            self.csv['sirna'][i] = str(new_class)
            new_class += 1
          else:
            self.csv['sirna'][i] = self.seen[sirna]
        self.num_classes = len(self.seen)

        self.root_dir = root_dir
        # The mean and stds for each of the channels
        self.GLOBAL_PIXEL_STATS = (torch.tensor([6.74696984, 14.74640167, 10.51260864,
                                             10.45369445,  5.49959796, 9.81545561]).reshape((-1,1,1)),
                                   torch.tensor([7.95876312, 12.17305868, 5.86172946,
                                             7.83451711, 4.701167, 5.43130431]).reshape((-1,1,1)))
    
    def __len__(self):
        return self.size #self.csv.shape[0]

    def __getitem__(self, idx):
        # Generate full filename of image from csv file row info
        pathParts = self.csv.iloc[idx,:]
        pathGen = os.path.join(self.root_dir, pathParts['experiment'], pathParts['plate'])
        filenameGen = pathParts['well']+'_'+pathParts['site']+'_w'
        for i in range(1,7):
            filenameFull = filenameGen+str(i)+'.png'
            pathFull = os.path.join(pathGen, filenameFull)
            image = io.imread(pathFull)
            if i == 1:
                totalTensor = torch.from_numpy(image).unsqueeze(0)
            else:
                imageTensor = torch.from_numpy(image).unsqueeze(0)
                totalTensor = torch.cat( (totalTensor, imageTensor), 0)
        
        try:
            sirna = self.csv.iloc[idx,:].loc['sirna']
        except:
            sirna = -2
        
        if sirna=='UNTREATED': sirna = -1
        else: sirna = float(re.search('[0-9]+', sirna).group())
        sirnaTensor = torch.tensor([sirna])

        # Normalize data using global pixel stats
        totalTensor = (totalTensor-self.GLOBAL_PIXEL_STATS[0]) / self.GLOBAL_PIXEL_STATS[1]

        # Apply transformation
        if self.transform != None:
            toPil = transforms.ToPILImage()
            randTransform = transforms.RandomOrder(self.transform)
            toTensor = transforms.ToTensor()
            tensorHalves = torch.split(totalTensor, 3, dim=0)
            for half in tensorHalves:
              half = toTensor(randTransform(toPil(half)))
            totalTensor = torch.cat(tensorHalves, dim=0)

        return totalTensor.float(), sirnaTensor.float()

In [6]:
rotations = [transforms.RandomRotation((90,90)),transforms.RandomRotation((180,180)),transforms.RandomRotation((270,270)), transforms.RandomRotation((0,0))]
transformList = [transforms.RandomHorizontalFlip(), transforms.RandomChoice(rotations)]
test_dataset = RecursionDataset(csv_file1='train-labels/train.csv',
                                root_dir='train-data',
                                csv_file2='train-labels/train_controls.csv',
                                transform=transformList)

model = models.resnet18(pretrained=True)

weights = model.conv1.weight
model.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.fc = nn.Linear(in_features=512, out_features=test_dataset.num_classes, bias=True)
#model.conv1.weights = torch.cat((weights,weights), dim=1)

'''for param in model.named_parameters():
  if param[1].dim() > 1:
    nn.init.xavier_normal_(param[1])
  else:
    nn.init.normal_(param[1])'''

model.cuda()
objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4) #optim.SGD(model.parameters(), lr=.0001, momentum=0.9, weight_decay=1e-4, nesterov=True) #

losses = []
full_loss = []
full_acc = []

scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=1)
loss_report = 0
acc_report = 0
dataloader = DataLoader(test_dataset, batch_size=4, pin_memory=True, shuffle=True)
iters = len(dataloader)
for epoch in range(5):
  loop = tqdm(total=len(dataloader), position=0, file=sys.stdout)

  for batch, (x, y_truth) in enumerate(dataloader):
    
    x, y_truth = x.cuda(async=True), y_truth.cuda(async=True)
    y_truth = y_truth.type(torch.cuda.LongTensor).squeeze(1) #NOTE: making y_hat a 1D tensor for crossEnropyLoss function

    #scheduler.step(epoch + batch / iters)
    optimizer.zero_grad()
    y_hat = model(x)

    loss = objective(y_hat, y_truth)
    
    loss.backward()
    
    losses.append(loss)

    loop.set_description('epoch:{}, batch loss:{:.4f}, avg train loss:{:.3f}, train acc:{:.6f}, lr: {:.6f}'.format(epoch, loss, loss_report, acc_report, optimizer.param_groups[0]['lr']))
    loop.update(1)

    optimizer.step()

  loop.close()

  print('\nGround Truth:', ' '.join('%5s' % y_truth[j].item() for j in range(4)))
  _, predicted = torch.max(y_hat, 1)
  print('Predicted: ', ' '.join('%5s' % predicted[j].item() for j in range(4)))

total = 0
correct = 0
loop = tqdm(total=len(dataloader), position=0)
for images, labels in dataloader:
  images, labels = images.cuda(async=True), labels.cuda(async=True)
  labels = labels.type(torch.cuda.LongTensor).squeeze(1)
  y_hat = model(x)
  _, predicted = torch.max(y_hat, 1)
  total += labels.size(0)
  correct += (predicted == y_truth).sum().item()
  loop.update(1)
loop.close()
acc_report = correct/total
print('Accuracy:', acc_report)
  
gc.collect()
print('GPU Mem Used:', torch.cuda.memory_allocated(0) / 1e9)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth


HBox(children=(IntProgress(value=0, max=46827520), HTML(value='')))


epoch:0, batch loss:3.3736, avg train loss:0.000, train acc:0.000000, lr: 0.000100: 100%|██████████| 16/16 [00:05<00:00,  2.89it/s]

Ground Truth:    11    20    27    11
Predicted:      8    13    26     6
epoch:1, batch loss:2.7352, avg train loss:0.000, train acc:0.000000, lr: 0.000100: 100%|██████████| 16/16 [00:05<00:00,  3.11it/s]

Ground Truth:    19    23     1    26
Predicted:     19     0    12    31
epoch:2, batch loss:2.4186, avg train loss:0.000, train acc:0.000000, lr: 0.000100: 100%|██████████| 16/16 [00:05<00:00,  3.14it/s]

Ground Truth:    10    16     7     9
Predicted:     15    16    31     9
epoch:3, batch loss:2.8778, avg train loss:0.000, train acc:0.000000, lr: 0.000100: 100%|██████████| 16/16 [00:05<00:00,  3.16it/s]

Ground Truth:    24     8     7     1
Predicted:     27    23     5     1
epoch:4, batch loss:1.4208, avg train loss:0.000, train acc:0.000000, lr: 0.000100: 100%|██████████| 16/16 [00:05<00:00,  3.16it/s]

  0%|          | 0/16 [00:00<?, ?it/s]



Ground Truth:    15     2     8    24
Predicted:     15     2     8     7


100%|██████████| 16/16 [00:04<00:00,  3.96it/s]

Accuracy: 0.75
GPU Mem Used: 0.686641664



