In [59]:
import numpy as np
import os
from pathlib import Path
from PIL import Image
import json
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms as T
import torch
import random
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.svm import SVC
from sklearn.decomposition import PCA

from CellDataset import CellDataset, moco_transform
from MoCoResNetBackbone import MoCoResNetBackbone

In [60]:
modelPath = Path("/scratch/cv-course-group-5/models/training5/model_epoch30.pth")

gpu = 0

device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

model = MoCoResNetBackbone()

model_state_dict = torch.load(modelPath, map_location=device)
model.load_state_dict(model_state_dict)
model.eval()
model.to(device)

with open(Path('train_test_split.json'), 'r') as f:
    _split_data = json.load(f)
val_list = _split_data.get("val", [])
dataset = CellDataset(video_list=val_list[:5], mode='inference')

22447


In [61]:
batch_size = 128
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

all_embeddings = []
all_labels = []

In [62]:
with torch.no_grad():
    for imgs, labels in tqdm(dataloader, desc="Extracting embeddings"):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.cpu().numpy()  # optional, if tensor
        embeddings = model.encode_query(imgs)  # → (B, 2048)
        embeddings = embeddings.cpu().numpy()

        all_embeddings.append(embeddings)
        all_labels.append(labels)

embeddings = np.concatenate(all_embeddings, axis=0)
labels = np.concatenate(all_labels, axis=0)

Extracting embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████| 176/176 [00:31<00:00,  5.52it/s]


In [63]:
class EmbeddingDataset(Dataset):
    def __init__(self, data_array: np.ndarray, label_array: np.ndarray):
        assert data_array.shape[0] == label_array.shape[0], "Mismatched data and labels"
        self.data = torch.from_numpy(data_array).float()
        self.labels = torch.from_numpy(label_array).byte()  # Use .float() if labels are floats

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [64]:

train_dataset = EmbeddingDataset(data_array=embeddings[:int(0.8 * len(embeddings))], label_array=labels[:int(0.8 * len(embeddings))])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

val_dataset = EmbeddingDataset(data_array=embeddings[int(0.8 * len(embeddings)):], label_array=labels[int(0.8 * len(embeddings)):])
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

In [58]:
_, train_labels = train_dataset[:]
_, val_labels = val_dataset[:]

linear_model = nn.Sequential(
    nn.Linear(embeddings.shape[1], 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)

epochs = 5
learning_rate = 0.001

pos_weight = (len(train_labels) - train_labels.sum()) / train_labels.sum()
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(linear_model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    epoch_loss = []
    linear_model.train()
    for embedding, label in tqdm(train_dataloader, desc=f"Epoch {epoch}", total=len(train_dataloader), ncols=100):
        pred = linear_model(embedding)
        loss = loss_fn(pred, label.unsqueeze(1).float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    print(f"Train Loss Epoch {epoch}: {np.mean(epoch_loss)}", end="")
    linear_model.eval()
    val_loss = []
    pred_val = []
    for embedding, label in val_dataloader:
        pred = (torch.sigmoid(linear_model(embedding)) > 0.5).squeeze().int()
        val_loss.append(np.sum([label != pred]) / batch_size)
        pred_val.append(pred.cpu().numpy())
    print(classification_report(np.concatenate(pred_val, axis=0), val_labels))
    print(f"Val Loss Epoch {epoch}: {np.mean(val_loss)}")


Epoch 0: 100%|████████████████████████████████████████████████████| 141/141 [00:28<00:00,  4.99it/s]

Train Loss Epoch 0: 0.027198325845921505




              precision    recall  f1-score   support

           0       0.20      0.03      0.06       976
           1       0.78      0.96      0.86      3514

    accuracy                           0.76      4490
   macro avg       0.49      0.50      0.46      4490
weighted avg       0.66      0.76      0.69      4490

Val Loss Epoch 0: 0.18294270833333334


Epoch 1: 100%|████████████████████████████████████████████████████| 141/141 [00:28<00:00,  5.01it/s]

Train Loss Epoch 1: 0.023411371586646173




              precision    recall  f1-score   support

           0       0.20      0.03      0.05      1026
           1       0.77      0.96      0.86      3464

    accuracy                           0.75      4490
   macro avg       0.48      0.50      0.45      4490
weighted avg       0.64      0.75      0.67      4490

Val Loss Epoch 1: 0.19422743055555555


Epoch 2: 100%|████████████████████████████████████████████████████| 141/141 [00:28<00:00,  5.01it/s]

Train Loss Epoch 2: 0.022691497244006882




              precision    recall  f1-score   support

           0       0.44      0.04      0.07      1777
           1       0.61      0.97      0.74      2713

    accuracy                           0.60      4490
   macro avg       0.52      0.50      0.41      4490
weighted avg       0.54      0.60      0.48      4490

Val Loss Epoch 2: 0.3537326388888889


Epoch 3: 100%|████████████████████████████████████████████████████| 141/141 [00:28<00:00,  4.99it/s]

Train Loss Epoch 3: 0.021671189774646826




              precision    recall  f1-score   support

           0       0.16      0.03      0.04      1000
           1       0.77      0.96      0.86      3490

    accuracy                           0.75      4490
   macro avg       0.47      0.49      0.45      4490
weighted avg       0.64      0.75      0.68      4490

Val Loss Epoch 3: 0.1872829861111111


Epoch 4: 100%|████████████████████████████████████████████████████| 141/141 [00:29<00:00,  4.79it/s]

Train Loss Epoch 4: 0.01990283610206758




              precision    recall  f1-score   support

           0       0.25      0.04      0.07      1010
           1       0.78      0.97      0.86      3480

    accuracy                           0.76      4490
   macro avg       0.51      0.50      0.46      4490
weighted avg       0.66      0.76      0.68      4490

Val Loss Epoch 4: 0.19032118055555555


In [65]:
random_forest = RandomForestClassifier(n_estimators=500, criterion="entropy", class_weight="balanced", n_jobs=8, random_state=42)

embeddings_train, labels_train = train_dataset[:]
embeddings_val, labels_val = val_dataset[:]

random_forest.fit(embeddings_train, labels_train)

0,1,2
,n_estimators,500
,criterion,'entropy'
,max_depth,
,min_samples_split,2
,min_samples_leaf,1
,min_weight_fraction_leaf,0.0
,max_features,'sqrt'
,max_leaf_nodes,
,min_impurity_decrease,0.0
,bootstrap,True


In [66]:
pred_val = random_forest.predict(embeddings_val)

print(classification_report(labels_val, pred_val))

              precision    recall  f1-score   support

           0       0.08      0.01      0.01       161
           1       0.96      1.00      0.98      4329

    accuracy                           0.96      4490
   macro avg       0.52      0.50      0.50      4490
weighted avg       0.93      0.96      0.95      4490



In [67]:
pca = PCA(256)
embeddings_train_reduced = pca.fit_transform(embeddings_train)
embeddings_val_reduced = pca.transform(embeddings_val)

In [68]:
svm = SVC(kernel="rbf", probability=True, class_weight="balanced", random_state=42)

svm.fit(embeddings_train_reduced, labels_train)

pred_val = svm.predict(embeddings_val_reduced)

print(classification_report(labels_val, pred_val))

              precision    recall  f1-score   support

           0       0.16      0.86      0.26       161
           1       0.99      0.83      0.90      4329

    accuracy                           0.83      4490
   macro avg       0.57      0.84      0.58      4490
weighted avg       0.96      0.83      0.88      4490

