In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import os
from pathlib import Path
import sys
import pandas as pd
from imblearn.over_sampling import RandomOverSampler
import datetime as dt


tri_path = os.environ.get('TRIPATH_DIR')
if tri_path and tri_path not in sys.path:
    sys.path.append(tri_path)
from models.feature_extractor import swin3d_b, swin3d_s
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter


In [25]:
# # Correct the shape of the arrays
# for path in paths:
#     array = np.load(path)
#     shape = array.shape
#     if shape != (30,30,30):
#         new_array = np.zeros((30,30,30))
#         new_array[:shape[0],:shape[1],:shape[2]] = array
#         np.save(path, new_array)


# Data

In [8]:
data_dir = Path(os.environ['DATA_DIR'])
patch_dir = data_dir / "patches"
paths = list(patch_dir.rglob("*.npy"))
patch_labels = [path.parent.name for path in paths]

In [3]:
pd.Series(patch_labels).value_counts()

apo-ferritin           349
ribosome               321
thyroglobulin          246
beta-galactosidase     110
virus-like-particle    110
beta-amylase            86
Name: count, dtype: int64

In [10]:
train_paths, test_paths, train_labels, test_labels = train_test_split(paths, patch_labels, test_size=0.3, stratify=patch_labels, shuffle=True, random_state=42)
train_paths, val_paths, train_labels, val_labels = train_test_split(train_paths, train_labels, test_size=0.3, stratify=train_labels, shuffle=True, random_state=42)

In [11]:
class PatchDataset(Dataset):
    def __init__(self, paths, patch_labels, train=True):
        self.paths = paths
        self.paths = np.array(self.paths).reshape(-1,1)
        self.le = LabelEncoder().fit(patch_labels)
        self.labels = self.le.transform(patch_labels)
        if train:
            self.over_sampler = RandomOverSampler(random_state=42)
            self.paths, self.labels = self.over_sampler.fit_resample(self.paths, self.labels)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        patch = np.load(path[0])
        patch = torch.from_numpy(patch).float()
        patch = patch.unsqueeze(0)
        patch = patch.repeat(3,1,1,1)
        label = self.labels[idx]
        return patch, label

In [19]:
training_dataset = PatchDataset(train_paths, train_labels, train=True)
training_dataloader = DataLoader(training_dataset, batch_size=32, shuffle=True)
validation_dataset = PatchDataset(val_paths, val_labels, train=False)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=True)

In [20]:
len(training_dataset), len(validation_dataset)

(1026, 257)

In [21]:
len(training_dataloader), len(validation_dataloader)

(33, 9)

# Model

In [30]:
class PatchClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = swin3d_s()
        self.feature_extractor.load_weights()
        self.fc1 = nn.Linear(768, 6)
        # self.fc2 = nn.Linear(512, 256)
        # self.fc3 = nn.Linear(256, 6)
        self.classifier = nn.Sequential(
            self.fc1,
            # nn.ReLU(),
            # self.fc2,
            # nn.ReLU(),
            # self.fc3,
            nn.Softmax(1)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x

In [31]:
model = PatchClassifier()
for param in model.feature_extractor.parameters():
    param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train_one_epoch(epoch_index, writer):
    running_loss = 0.
    last_loss = 0.

    for i, batch in enumerate(training_dataloader):
        patches, labels = batch
        optimizer.zero_grad()

        output = model(patches)
        loss = criterion(output, labels)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()

        if i % 10 == 9:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_dataloader) + i + 1
            writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss

Loading pretrained video weights


In [32]:
timestamp = dt.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in tqdm(range(EPOCHS)):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train()
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(validation_dataloader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs)
            vloss = criterion(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

  0%|          | 0/5 [00:00<?, ?it/s]

EPOCH 1:
  batch 10 loss: 0.017923983454704286
  batch 20 loss: 0.01791251838207245


  0%|          | 0/5 [00:34<?, ?it/s]


KeyboardInterrupt: 