In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import matplotlib.pyplot as plt
import scanpy as sc
import anndata
import torch
import numpy as np 
import pandas as pd
import torch.nn.functional as F
from torch import nn
from torch import optim
from sklearn.model_selection import KFold, StratifiedKFold
from torch.utils.data import DataLoader,TensorDataset,SubsetRandomSampler, WeightedRandomSampler

sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80, frameon=False, figsize=(3, 3), facecolor='white')

In [11]:
humdata = np.load('5 x_hum_0831_anno.npz', allow_pickle=True)
x_hum = torch.from_numpy(humdata['x'].astype(np.float32))
y_hum = torch.from_numpy(humdata['y'].astype(np.int32))
y_hum = y_hum.long() - 1

inputSize = x_hum.shape[1]
outputSize = 14
hiddenSize = 16
num_epochs = 30
batch_size = 8
lr = 0.001

sample_size = 8
sample_idx=[]
for t in np.unique(y_hum):
    t_idx = np.where(y_hum==t)[0]
    sample_idx.append(np.random.choice(t_idx, size=sample_size))
sample_idx = np.vstack(sample_idx)
sample_idx = sample_idx.reshape(sample_idx.size)
x_train = x_hum[sample_idx]
y_train = y_hum[sample_idx]

dataset = TensorDataset(x_train, y_train)

In [12]:
class OutputHook(list):
    def __call__(self, module, input, output):
        self.append(output)
        
class ANN(nn.Module):
    def __init__(self):
        super(ANN, self).__init__()
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Linear(inputSize, hiddenSize),
            nn.Dropout(0.5),
            nn.Linear(hiddenSize, outputSize),
            self.relu
        )
            
    def forward(self, x):
        out = self.layers(x)
        if not self.training:
            out = F.softmax(out, dim=1)
        return out

In [13]:
device = "cuda:0"
torch.manual_seed(42)

l1_lambda = 1e-3
loss_fn = nn.CrossEntropyLoss()

def train(model, device, dataloader, loss_fn, optimizer):
    train_loss, train_correct=0.0,0
    model.train()

    for inputs, labels in dataloader:
        inputs,labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(inputs.float())

        l1_penalty = 0.
        for output in output_hook:
            l1_penalty += torch.norm(output, 1)
        l1_penalty *= l1_lambda

        loss = loss_fn(output,labels) + l1_penalty
        loss.backward()
        optimizer.step()
        output_hook.clear()
        
        train_loss += loss.item() * inputs.size(0)
        scores, predictions = torch.max(output.data, 1)
        train_correct += (predictions == labels).sum().item()

    return train_loss,train_correct

def val(model, device, dataloader, loss_fn):
    valid_loss, val_correct = 0.0, 0
    model.eval()

    for inputs, labels in dataloader:
        inputs,labels = inputs.to(device),labels.to(device)
        output = model(inputs)
        loss = loss_fn(output,labels)
        valid_loss += loss.item()*inputs.size(0)
        scores, predictions = torch.max(output.data,1)
        val_correct += (predictions == labels).sum().item()

    return valid_loss,val_correct


train_loader = DataLoader(dataset, batch_size=batch_size)
test_loader = DataLoader(dataset, batch_size=batch_size)

model = torch.load('ANN_0912_equalSample.pt')
model.to(device)
optimizer = optim.RMSprop(model.parameters(), lr=lr)

output_hook = OutputHook()
model.relu.register_forward_hook(output_hook)


for epoch in range(num_epochs):
    train_loss, train_correct = train(model, device, train_loader, loss_fn, optimizer)
    test_loss, test_correct = val(model, device, test_loader, loss_fn)

    train_loss = train_loss / len(train_loader.sampler)
    train_acc = train_correct / len(train_loader.sampler) * 100
    test_loss = test_loss / len(test_loader.sampler)
    test_acc = test_correct / len(test_loader.sampler) * 100

    print("Epoch:{}/{} AVG Training Loss:{:.3f} AVG Test Loss:{:.3f} AVG Training Acc {:.2f} % AVG Test Acc {:.2f} %".format(
        epoch + 1, num_epochs, train_loss, test_loss, train_acc, test_acc))

Epoch:1/30 AVG Training Loss:2.788 AVG Test Loss:2.228 AVG Training Acc 25.00 % AVG Test Acc 70.45 %
Epoch:2/30 AVG Training Loss:1.657 AVG Test Loss:2.082 AVG Training Acc 51.14 % AVG Test Acc 77.27 %
Epoch:3/30 AVG Training Loss:1.271 AVG Test Loss:2.026 AVG Training Acc 65.91 % AVG Test Acc 82.95 %
Epoch:4/30 AVG Training Loss:1.212 AVG Test Loss:1.978 AVG Training Acc 68.18 % AVG Test Acc 82.95 %
Epoch:5/30 AVG Training Loss:0.939 AVG Test Loss:1.947 AVG Training Acc 77.27 % AVG Test Acc 84.09 %
Epoch:6/30 AVG Training Loss:0.940 AVG Test Loss:1.932 AVG Training Acc 77.27 % AVG Test Acc 85.23 %
Epoch:7/30 AVG Training Loss:0.826 AVG Test Loss:1.914 AVG Training Acc 79.55 % AVG Test Acc 87.50 %
Epoch:8/30 AVG Training Loss:0.713 AVG Test Loss:1.888 AVG Training Acc 87.50 % AVG Test Acc 92.05 %
Epoch:9/30 AVG Training Loss:0.736 AVG Test Loss:1.856 AVG Training Acc 82.95 % AVG Test Acc 93.18 %
Epoch:10/30 AVG Training Loss:0.592 AVG Test Loss:1.844 AVG Training Acc 88.64 % AVG Test A

In [14]:
torch.save(model,'ANN_0912_retrain.pt')