In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import confusion_matrix


In [2]:
cells = np.load("cell_channels.npy", allow_pickle=True)
# cells = cells[:2000]
labels = np.load("cell_labels.npy", allow_pickle=True)
# labels = labels[:2000]

In [3]:
batch_size = 32
lr = 0.01
num_epochs = 10

In [4]:
class CellData(Dataset):
    def __init__(self, cells, labels):
        self.cells = cells
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        x = self.cells[index]
        y = self.labels[index]
        return x, y

In [5]:
cell_data = CellData(cells, labels)
device = torch.device('cuda')
train_size = int(0.8 * len(cell_data))
test_size = len(cell_data) - train_size
train_set, test_set = torch.utils.data.random_split(cell_data, [train_size, test_size])
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
classes = np.unique(labels)

In [6]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(51, 32, kernel_size=(3, 3), stride=1, padding=1)
        self.act1 = nn.ReLU()
        self.drop = nn.Dropout(0.1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=1, padding=1)
        self.act2 = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

        self.fc1 = nn.Linear(64 * 32 * 32, 256)
        self.act3 = nn.ReLU()
        self.fc2 = nn.Linear(256, 18)
        # self.softmax = nn.Softmax

    def forward(self, x):
        x = self.act1(self.conv1(x))
        x = self.drop(x)
        x = self.act2(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(-1, 64 * 32 * 32)
        x = self.act3(self.fc1(x))
        # x = self.softmax(self.fc2(x))
        x = self.fc2(x)
        return x


In [7]:
model = CNN()
model.double()
lossfn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)

model = model.to(device)

After running for 10 epochs, I noticed that the loss/accuracy already plateaued, so I didn't train it any farther.

In [8]:
for epoch in range(num_epochs):
    print("Epoch: {} of {}".format(epoch, num_epochs))
    train_loss = 0
    train_correct = 0
    for inputs, targets in tqdm(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        inputs = inputs.permute(0, 3, 1, 2)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = lossfn(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        train_correct += (predicted == targets).sum().item()
    train_loss /= len(train_loader)
    train_acc = train_correct / len(train_set)
    print("Training loss: {:.4f}".format(train_loss))
    print("Training accuracy: {:.2f}".format(train_acc*100))

    model.eval()
    test_loss = 0
    test_correct = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            inputs = inputs.permute(0, 3, 1, 2)
            outputs = model(inputs)
            loss = lossfn(outputs, targets)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            test_correct += (predicted == targets).sum().item()
    test_loss /= len(test_loader)
    test_acc = test_correct / len(test_set)
    print("Test loss: {:.4f}".format(test_loss))
    print("Test accuracy: {:.2f}".format(test_acc*100))


Epoch: 0 of 10


100%|█████████████████████████████████████████| 133/133 [00:31<00:00,  4.26it/s]


Training loss: 2918.4842
Training accuracy: 35.09
Test loss: 2.1517
Test accuracy: 23.84
Epoch: 1 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.85it/s]


Training loss: 1.8964
Training accuracy: 41.81
Test loss: 1.7927
Test accuracy: 44.84
Epoch: 2 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.81it/s]


Training loss: 1.8371
Training accuracy: 43.97
Test loss: 1.8232
Test accuracy: 44.84
Epoch: 3 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.78it/s]


Training loss: 1.8226
Training accuracy: 43.97
Test loss: 1.7929
Test accuracy: 44.84
Epoch: 4 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.77it/s]


Training loss: 1.8295
Training accuracy: 43.97
Test loss: 1.8752
Test accuracy: 44.84
Epoch: 5 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.78it/s]


Training loss: 1.8241
Training accuracy: 43.97
Test loss: 1.9260
Test accuracy: 44.84
Epoch: 6 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.77it/s]


Training loss: 1.8163
Training accuracy: 43.97
Test loss: 1.7984
Test accuracy: 44.84
Epoch: 7 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.78it/s]


Training loss: 1.8197
Training accuracy: 43.97
Test loss: 1.8007
Test accuracy: 44.84
Epoch: 8 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.78it/s]


Training loss: 1.8245
Training accuracy: 43.97
Test loss: 1.8245
Test accuracy: 44.84
Epoch: 9 of 10


100%|█████████████████████████████████████████| 133/133 [00:27<00:00,  4.78it/s]


Training loss: 1.8213
Training accuracy: 43.97
Test loss: 1.7998
Test accuracy: 44.84


In [9]:
torch.save(model.state_dict(), "model_weights.pt")

In [10]:
model = CNN()
model.double()
model.load_state_dict(torch.load("model_weights.pt"))
model.eval()
test_pred = []
test_true = []
with torch.no_grad():
    for inputs, targets in tqdm(test_loader):
        inputs = inputs.permute(0, 3, 1, 2)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        test_true.extend(targets.numpy())
        test_pred.extend(predicted.numpy())

test_pred = np.array(test_pred)
test_true = np.array(test_true)


100%|███████████████████████████████████████████| 34/34 [00:20<00:00,  1.69it/s]


In [12]:
con_mat = confusion_matrix(test_true, test_pred, labels=range(18))
np.save('confusion_matrix.npy', con_mat)