In [None]:
import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from pathlib import Path

from collections import defaultdict

import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset

import torchvision.datasets as dset
import torchvision.transforms as T

In [None]:
# Runtime configuration shared by the cells below.
USE_GPU = True
dtype = torch.float32  # We will be using float throughout this tutorial.

# Use CUDA only when it is both requested and actually available;
# otherwise everything runs on the CPU.
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')

# Constant to control how frequently we print train loss.
print_every = 100
print('using device:', device)

In [None]:
class LazyPickleDataset(Dataset):
    """Map-style dataset that serves items of a pickled sequence chunk by chunk.

    Between calls only one slice of ``chunk_size`` consecutive items is kept
    on the instance.  A pickle file cannot be read partially, so every chunk
    switch re-reads and unpickles the *whole* file before slicing it; access
    patterns that jump between chunks (for example shuffled indices) pay that
    cost on almost every lookup.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading.  Only
    point this class at files that come from a trusted source.
    """

    def __init__(self, file_path, chunk_size=1000):
        """Record the file location and read the file once to learn its length.

        Args:
            file_path: Path of a pickle file holding a single object that
                supports ``len()`` and slicing.
            chunk_size: Number of consecutive items cached at a time.  Must be
                positive.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        # Reject sizes that would otherwise only fail later, inside
        # __getitem__, with a ZeroDivisionError or an index into ``None``.
        if chunk_size <= 0:
            raise ValueError('chunk_size must be a positive integer, got %r' % (chunk_size,))
        self.file_path = file_path
        self.chunk_size = chunk_size
        self.total_length = self._get_total_length()
        self.current_chunk = None
        self.current_chunk_index = -1  # -1 means "no chunk cached yet"

    def _read_all(self):
        """Unpickle and return the complete object stored at ``file_path``."""
        with open(self.file_path, 'rb') as f:
            return pickle.load(f)

    def _get_total_length(self):
        """Return the number of items in the pickled object (reads the whole file)."""
        return len(self._read_all())

    def _load_chunk(self, chunk_index):
        """Return the slice of items that belongs to chunk ``chunk_index``.

        The last chunk is shorter when the total length is not a multiple of
        ``chunk_size``.
        """
        start = chunk_index * self.chunk_size
        end = min(start + self.chunk_size, self.total_length)
        return self._read_all()[start:end]

    def __len__(self):
        """Return the number of items in the dataset."""
        return self.total_length

    def __getitem__(self, index):
        """Return the item at ``index``, loading its chunk if it is not cached.

        Raises:
            IndexError: If ``index`` is negative or not smaller than the
                dataset length.
        """
        if index < 0 or index >= self.total_length:
            raise IndexError('Index out of range')

        chunk_index = index // self.chunk_size
        if chunk_index != self.current_chunk_index:
            # Cache miss: replace the cached slice with the requested chunk.
            self.current_chunk = self._load_chunk(chunk_index)
            self.current_chunk_index = chunk_index

        item_index = index % self.chunk_size
        return self.current_chunk[item_index]

In [None]:
# Items are read from 'x.pickle' and cached 1000 at a time.
dataset = LazyPickleDataset(file_path='x.pickle', chunk_size=1000)

In [None]:
# Load the metadata table and pull out its "Celltype" column as a NumPy array.
meta_data = pd.read_csv('metadata_seurat_onelayer.csv')
celltype_column = meta_data["Celltype"]
y = np.array(celltype_column)

In [None]:
# Distinct cell types present in the metadata.
y_set = set(y)
print(len(y_set))

# Map every cell type to an integer class id.  The ids are assigned in sorted
# order rather than in set-iteration order: iterating a set of strings can
# yield a different order on every interpreter run, which would silently
# re-number the classes between runs.  `key=str` keeps the sort well defined
# even if the column contains non-string values.
# A plain dict is used so that looking up an unknown cell type raises KeyError
# instead of being silently mapped to class 0.
labels_dict = {cell_type: class_id
               for class_id, cell_type in enumerate(sorted(y_set, key=str))}
label_num = len(labels_dict)

# Encode each row's cell type as its class id (int64, the dtype expected for
# classification targets).
labels = np.array([labels_dict[item] for item in y], dtype=np.int64)
labels = torch.from_numpy(labels)

In [None]:
# Number of samples in the whole dataset and in each split.
TOTAL_DATAPOINTS = 14322
NUM_TRAIN = 12890
NUM_VAL = 716
NUM_TEST = 716
assert NUM_TRAIN + NUM_VAL + NUM_TEST == TOTAL_DATAPOINTS, \
    'split sizes must add up to TOTAL_DATAPOINTS'

# Make sure `labels` is an integer (long) tensor.  `torch.as_tensor` is a
# no-op when it already is one, so re-running this cell is safe; the legacy
# `torch.Tensor(...)` constructor is float-biased and should not be used to
# wrap existing integer data.
labels = torch.as_tensor(labels, dtype=torch.long)

In [None]:
# Randomly partition the dataset.  The lengths are listed in the same order as
# the names they are unpacked into (train, val, test), and the split uses its
# own seeded generator so that the same samples end up in the same subset on
# every run -- otherwise the held-out sets change whenever the cell is re-run.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [NUM_TRAIN, NUM_VAL, NUM_TEST], generator=split_generator)

# One loader per subset, all with the same batch size.
loader_train = DataLoader(train_set, batch_size=64)
loader_val = DataLoader(val_set, batch_size=64)
loader_test = DataLoader(test_set, batch_size=64)

In [None]:
def train(model, optimizer, epochs=1, start_epoch=0):
    """Train `model` on `loader_train` and periodically report accuracy.

    Every batch gets a channel axis inserted at dim 1 before it is fed to the
    model, and the loss is the cross entropy between the model scores and the
    integer targets.  Every `print_every` batches (counted from the start of
    each epoch) the current loss is recorded and accuracy is measured on the
    full `loader_train` and `loader_val`; the three histories are rewritten to
    "losslist.txt", "train.txt" and "val_acc.txt" each time.

    Args:
        model: Module to optimise; it is moved to the global `device`.
        optimizer: Optimizer built over `model`'s parameters.
        epochs: Number of passes over `loader_train`.
        start_epoch: Accepted for compatibility; not used by this function.

    Returns:
        Tuple `(loss_list, train_acc_list, val_acc_list)` holding the values
        recorded at each reporting step.
    """
    model = model.to(device=device)
    loss_list = []
    val_acc_list = []
    train_acc_list = []
    for e in range(epochs):
        print(f"epoch: {e+1}/{epochs}")
        for t, (x, y) in enumerate(loader_train):
            x = x.unsqueeze(dim=1)
            # check_accuracy() leaves the model in eval mode, so training mode
            # has to be re-enabled before every step.
            model.train()
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # `shape` is an attribute, not a method; `item()` prints the loss
            # as a plain number instead of a tensor repr.
            print("iteration of data loader is ", t, " with x of shape ", tuple(x.shape), " and loss of ", loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if t % print_every == 0:
                print('Iteration %d, loss = %.4f' % (t, loss.item()))
                loss_list.append(loss.item())
                print("Train Accuracy:")
                train_acc = check_accuracy(loader_train, model)
                print("Validation Accuracy:")
                val_acc = check_accuracy(loader_val, model)
                train_acc_list.append(train_acc)
                val_acc_list.append(val_acc)
                # Persist the histories so they survive an interrupted run.
                np.savetxt("train.txt", np.array(train_acc_list))
                np.savetxt("val_acc.txt", np.array(val_acc_list))
                np.savetxt("losslist.txt", np.array(loss_list))
                print()

    return loss_list, train_acc_list, val_acc_list
            
def check_accuracy(loader, model):
    """Measure and print the classification accuracy of `model` on `loader`.

    The model is switched to eval mode (and left there) and gradients are
    disabled.  Each batch gets a channel axis inserted at dim 1, and the
    predicted class is the index of the highest score along dim 1.

    Args:
        loader: Iterable yielding `(x, y)` batches.
        model: Module to evaluate; expected to already live on `device`.

    Returns:
        Fraction of correctly classified samples as a float; 0.0 when the
        loader yields no samples.
    """
    num_correct = 0
    num_samples = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.unsqueeze(dim=1)
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            # .item() keeps the running count a plain Python int rather than
            # a tensor living on the device.
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)
    # An empty loader would otherwise divide by zero.
    acc = float(num_correct) / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc

In [None]:
NUM_CLASSES = 22

# Convolutional feature extractor: two conv/conv/pool stages (128 then 256
# channels, kernel size 19) followed by a flatten.
feature_layers = [
    nn.Conv1d(1, 128, 19),
    nn.ReLU(),
    nn.BatchNorm1d(128),
    nn.Conv1d(128, 128, 19),
    nn.ReLU(),
    nn.MaxPool1d(2, stride=2),
    nn.BatchNorm1d(128),
    nn.Conv1d(128, 256, 19),
    nn.ReLU(),
    nn.BatchNorm1d(256),
    nn.Conv1d(256, 256, 19),
    nn.ReLU(),
    nn.MaxPool1d(2, stride=2),
    nn.BatchNorm1d(256),
    nn.Flatten(),
]

# Size the classifier from the data instead of hard-coding it.  The flattened
# feature count is always a multiple of 256 (the channel count of the last
# stage); the previous constant 1499456 is not, so it could not match any
# input length.  One real sample is pushed through the extractor, prepared the
# same way train() prepares its batches, to read off the true size.  Eval mode
# plus no_grad keeps the probe from touching the BatchNorm running statistics.
feature_probe = nn.Sequential(*feature_layers).eval()
with torch.no_grad():
    probe_x, _ = next(iter(loader_train))
    probe_x = probe_x[:1].unsqueeze(dim=1).to(dtype=dtype)
    num_flat_features = feature_probe(probe_x).shape[1]

model = nn.Sequential(*feature_layers, nn.Linear(num_flat_features, NUM_CLASSES))
model.train()  # undo the eval() applied to the shared layers by the probe

learning_rate = 5e-5
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)

loss_list, train_acc_list, val_acc_list = train(model, optimizer, epochs=250)

In [None]:
# Multinomial logistic regression baseline: a single linear layer over the raw
# features.  train() and check_accuracy() always insert a channel axis, giving
# input of shape (N, 1, features); without the Flatten the linear layer would
# emit (N, 1, 22) scores, which cross_entropy cannot pair with (N,) targets.
logistic_regression = nn.Sequential(
    nn.Flatten(),
    nn.Linear(187541, 22),
)

learning_rate = 5e-5
optimizer = optim.SGD(logistic_regression.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)

# NOTE: train() writes to fixed file names, so this run overwrites the
# histories saved by any earlier training run in this notebook.
loss_list, train_acc_list, val_acc_list = train(logistic_regression, optimizer, epochs=100)


In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 4))

# Reload the histories written during training.  np.loadtxt returns a 0-d
# array for a file with a single value; atleast_1d keeps len() and slicing
# working in that case.
val_acc_list = np.atleast_1d(np.loadtxt('val_acc.txt'))
train_acc_list = np.atleast_1d(np.loadtxt('train.txt'))
loss_list = np.atleast_1d(np.loadtxt('losslist.txt'))

print(len(val_acc_list))

# The x axis is the index of the recorded point (one point per reporting step
# of the training loop).  Each x vector is built from the length of the series
# it is plotted against, so the two can never disagree in size.
x_loss = np.arange(len(loss_list))

# Left panel: loss, skipping the first two recorded points.
axs[0].plot(x_loss[2:], loss_list[2:])
axs[0].set_title('Loss vs Iteration')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')

# Right panel: validation and training accuracy, skipping the first point.
axs[1].plot(np.arange(1, len(val_acc_list)), val_acc_list[1:], label='val')
axs[1].plot(np.arange(1, len(train_acc_list)), train_acc_list[1:], label='train')
axs[1].set_title('Accuracy vs Iteration')
axs[1].set_xlabel('Iteration')
axs[1].set_ylabel('Accuracy')
axs[1].legend()

plt.tight_layout()
plt.show()