# Train, test and evaluation

In [1]:
import os, torch

import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torch.nn.functional as F
import matplotlib.pyplot as plt

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from torchvision import transforms
from torchvision.models import vit_b_16  # ViT base, 16x16 patches
from collections import Counter
from loguru import logger
from sklearn.metrics import accuracy_score, f1_score, classification_report, ConfusionMatrixDisplay, confusion_matrix

from embryo_project.config import PROCESSED_DATA_DIR, MODELS_DIR, REPORTS_DIR, FIGURES_DIR

[32m2025-08-11 15:27:28.081[0m | [1mINFO    [0m | [36membryo_project.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: C:\Users\Molinari\Desktop\embryo-project[0m


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

[32m2025-08-11 15:27:28.104[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mUsing device: cuda[0m


In [3]:
class EmbryoSequenceDataset(Dataset):
    def __init__(self, root_dir, transform=None, max_seq_len=None):
        self.root_dir = root_dir
        self.transform = transform
        self.folders = sorted([f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))])
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, idx):
        folder = self.folders[idx]
        folder_path = os.path.join(self.root_dir, folder)

        image_files = sorted(
            [f for f in os.listdir(folder_path) if f.endswith('.JPG')],
            key=lambda x: int(x.split('_')[-1].split('.')[0])  # extract index
        )

        images = []
        for img_file in image_files:
            img_path = os.path.join(folder_path, img_file)
            img = Image.open(img_path).convert("RGB")
            if self.transform:
                img = self.transform(img)
            images.append(img)

        if self.max_seq_len:
            images = images[:self.max_seq_len]
            while len(images) < self.max_seq_len:
                images.append(torch.zeros_like(images[0]))  # zero padding

        images_tensor = torch.stack(images)  # [#images, #channels, heigth, weigth]

        # 1 if any image has "_1.JPG", else 0
        label = 1 if any(f.endswith('_1.JPG') for f in image_files) else 0

        return images_tensor, label

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

MAX_SEQ_LEN = 20 # all 20 images
# MAX_SEQ_LEN = 10 # only 10 images

train_dataset = EmbryoSequenceDataset(PROCESSED_DATA_DIR / "train", transform=transform, max_seq_len=MAX_SEQ_LEN)
val_dataset = EmbryoSequenceDataset(PROCESSED_DATA_DIR / "val", transform=transform, max_seq_len=MAX_SEQ_LEN)
test_dataset = EmbryoSequenceDataset(PROCESSED_DATA_DIR / "test", transform=transform, max_seq_len=MAX_SEQ_LEN)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)


In [5]:
class ResNet18LSTM(nn.Module):
    def __init__(self, cnn_embed_dim=512, lstm_hidden_size=128, num_layers=1, bidirectional=True):
        super(ResNet18LSTM, self).__init__()

        # resnet18
        resnet = models.resnet18(pretrained=True)
        modules = list(resnet.children())[:-1]  # remove final
        self.cnn = nn.Sequential(*modules)
        self.cnn_embed_dim = cnn_embed_dim  # 512 for resnet18

        # LSTM
        self.lstm = nn.LSTM(
            input_size=cnn_embed_dim,
            hidden_size=lstm_hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional
        )

        # binary classification
        direction_factor = 2 if bidirectional else 1
        self.classifier = nn.Linear(lstm_hidden_size * direction_factor, 1)

    def forward(self, x):
        # B = number of folders inside the batch
        # T = number of images in the folder
        # C = number of image channels
        # H = image height
        # W = image width
        B, T, C, H, W = x.size()

        x = x.view(B * T, C, H, W)
        cnn_feats = self.cnn(x).view(B, T, -1)  # [B*T, 512] -> [B, T, 512]

        lstm_out, _ = self.lstm(cnn_feats)  # [B, T, H]
        last_output = lstm_out[:, -1, :]    # last time step [B, H]

        logits = self.classifier(last_output)  # [B, 1]
        return logits

In [6]:
class ViTLSTM(nn.Module):
    def __init__(self, hidden_dim=256, lstm_layers=1, dropout=0.1):
        super(ViTLSTM, self).__init__()
        # Load pretrained ViT backbone without classification head
        self.vit = vit_b_16(pretrained=True)
        self.vit.heads = nn.Identity()  # remove ViT classifier head

        self.lstm = nn.LSTM(input_size=768,  # ViT base embedding size
                            hidden_size=hidden_dim,
                            num_layers=lstm_layers,
                            batch_first=True,
                            dropout=dropout if lstm_layers > 1 else 0)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x shape: [B, T, C, H, W] (sequence of images)
        B, T, C, H, W = x.shape
        x = x.view(B * T, C, H, W)  # flatten batch and time for ViT input
        vit_feats = self.vit(x)       # [B*T, 768]
        vit_feats = vit_feats.view(B, T, -1)  # reshape to sequence: [B, T, 768]

        lstm_out, _ = self.lstm(vit_feats)  # [B, T, hidden_dim]
        last_hidden = lstm_out[:, -1, :]    # take last time step [B, hidden_dim]

        out = self.classifier(last_hidden)  # [B, 1]
        return out.squeeze(1)  # [B]

In [7]:
class Simple3DCNN(nn.Module):
    def __init__(self):
        super(Simple3DCNN, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool3d(2)
        self.conv2 = nn.Conv3d(16, 32, 3, padding=1)
        self.pool2 = nn.MaxPool3d(2)
        self.conv3 = nn.Conv3d(32, 64, 3, padding=1)
        self.pool3 = nn.AdaptiveAvgPool3d(1)  # Output size = (batch, 64, 1, 1, 1)
        self.fc = nn.Linear(64, 1)  # Binary classification

    def forward(self, x):
        x = F.relu(self.conv1(x))  # [B,16,T/1,H/1,W/1]
        x = self.pool1(x)          # Downsample
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = self.pool3(x)          # [B,64,1,1,1]
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)             # [B,1]
        return x


In [8]:
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs=10,
    patience=5,
    best_model_path="best_model.pth",
    scheduler=None,
    use_early_stopping=True,
    is_vit=False,
    is_cnn=False
):
    best_val_f1 = 0.0
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # --- Training ---
        model.train()
        train_losses = []
        all_preds_train, all_labels_train = [], []

        for inputs, labels in train_loader:
            if is_cnn:
                inputs = inputs.permute(0, 2, 1, 3, 4).to(device)
            else:
                inputs = inputs.to(device)
            labels = labels.float().to(device)

            optimizer.zero_grad()
            if is_vit:
                outputs = model(inputs)
            else:
                outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            preds = (torch.sigmoid(outputs) > 0.5).int()
            all_preds_train.extend(preds.cpu().numpy())
            all_labels_train.extend(labels.cpu().numpy())

        train_acc = accuracy_score(all_labels_train, all_preds_train)
        train_f1 = f1_score(all_labels_train, all_preds_train)

        # --- Validation ---
        model.eval()
        val_losses = []
        all_preds_val, all_labels_val = [], []

        with torch.no_grad():
            for inputs, labels in val_loader:
                if is_cnn:
                    inputs = inputs.permute(0, 2, 1, 3, 4).to(device)
                else:
                    inputs = inputs.to(device)
                labels = labels.float().to(device)

                if is_vit:
                    outputs = model(inputs)
                else:
                    outputs = model(inputs).squeeze(1)
                loss = criterion(outputs, labels)
                val_losses.append(loss.item())

                preds = (torch.sigmoid(outputs) > 0.5).int()
                all_preds_val.extend(preds.cpu().numpy())
                all_labels_val.extend(labels.cpu().numpy())

        val_acc = accuracy_score(all_labels_val, all_preds_val)
        val_f1 = f1_score(all_labels_val, all_preds_val)

        logger.info(f"Epoch {epoch+1}/{num_epochs}")
        logger.info(f"  Train Loss: {sum(train_losses)/len(train_losses):.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
        logger.info(f"  Val   Loss: {sum(val_losses)/len(val_losses):.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")

        # Scheduler step (if provided)
        if scheduler is not None:
            scheduler.step()  

        # --- Early Stopping ---
        if use_early_stopping:
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), best_model_path)
                logger.success(f"New best model saved at {best_model_path} with F1: {best_val_f1:.4f}")
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    logger.warning("Early stopping triggered")
                    break
        else:
            # Always save best model regardless of patience
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), best_model_path)
                logger.success(f"New best model saved at {best_model_path} with F1: {best_val_f1:.4f}")

In [9]:
def evaluate_and_save_results(
        model,
        test_loader,
        device,
        model_name, 
        models_dir=MODELS_DIR,
        reports_dir=REPORTS_DIR,
        figures_dir=FIGURES_DIR,
        is_cnn=False
    ):
    
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(reports_dir, exist_ok=True)
    os.makedirs(figures_dir, exist_ok=True)

    best_model_path = os.path.join(models_dir, f"{model_name}.pth")
    model.load_state_dict(torch.load(best_model_path))
    model.eval()

    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in test_loader:
            if is_cnn:
                inputs = inputs.permute(0, 2, 1, 3, 4).to(device)
            else:
                inputs = inputs.to(device)
            labels = labels.float().to(device)
            outputs = model(inputs)
            preds = (torch.sigmoid(outputs) > 0.5).int()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # classification report
    report = classification_report(all_labels, all_preds, 
                                    target_names=["Class 0", "Class 1"], digits=4)
    report_path = os.path.join(reports_dir, f"{model_name}_report.txt")
    with open(report_path, "w") as f:
        f.write(report)
    logger.success(f"Classification report saved to {report_path}")

    # confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Class 0", "Class 1"])

    fig, ax = plt.subplots(figsize=(5, 5))
    disp.plot(ax=ax, cmap=plt.cm.Blues, values_format='d')
    plt.title(f"Confusion Matrix - {model_name}")
    cm_path = os.path.join(figures_dir, f"{model_name}_cm.png")
    plt.savefig(cm_path)
    plt.close()
    logger.success(f"Confusion matrix saved to {cm_path}")


In [10]:
# class distribution for weigths
labels = [label for _, label in train_dataset]
label_counts = Counter(labels)

neg = label_counts[0]
pos = label_counts[1]
pos_weight = torch.tensor([neg / pos]).to(device)

## ResNet18 + LSTM

In [11]:
model_name = "ResNet18LSTM"
os.makedirs(MODELS_DIR, exist_ok=True)
best_model_path = MODELS_DIR / f"{model_name}.pth"

resnet18 = ResNet18LSTM().to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.AdamW(resnet18.parameters(), lr=1e-4, weight_decay=1e-4)

num_epochs = 10
patience = 5
best_f1 = 0.0
epochs_no_improve = 0

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
# scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)



In [12]:
train_model(
    resnet18,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    best_model_path,
    scheduler,
    use_early_stopping=True
)

[32m2025-08-11 15:28:43.835[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m76[0m - [1mEpoch 1/10[0m
[32m2025-08-11 15:28:43.836[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m77[0m - [1m  Train Loss: 1.2855 | Acc: 0.6866 | F1: 0.1569[0m
[32m2025-08-11 15:28:43.836[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m78[0m - [1m  Val   Loss: 1.1656 | Acc: 0.7619 | F1: 0.2222[0m
[32m2025-08-11 15:28:43.880[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mtrain_model[0m:[36m89[0m - [32m[1mNew best model saved at C:\Users\Molinari\Desktop\embryo-project\models\ResNet18LSTM.pth with F1: 0.2222[0m
[32m2025-08-11 15:29:32.796[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m76[0m - [1mEpoch 2/10[0m
[32m2025-08-11 15:29:32.796[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m77[0m - [1m  Train Loss: 0.8417 | Acc: 0.8367 | F1: 0.3333[0m
[32m2025-08-11 15:29:32.796

In [13]:
evaluate_and_save_results(
    resnet18,
    test_loader,
    device,
    model_name
)

[32m2025-08-11 15:34:54.460[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mevaluate_and_save_results[0m:[36m40[0m - [32m[1mClassification report saved to C:\Users\Molinari\Desktop\embryo-project\reports\ResNet18LSTM_report.txt[0m
[32m2025-08-11 15:34:54.514[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mevaluate_and_save_results[0m:[36m52[0m - [32m[1mConfusion matrix saved to C:\Users\Molinari\Desktop\embryo-project\reports\figures\ResNet18LSTM_cm.png[0m


## ViT + LSTM

In [14]:
model_name = "ViTLSTM"
os.makedirs(MODELS_DIR, exist_ok=True)
best_model_path = MODELS_DIR / f"{model_name}.pth"

vitlstm = ViTLSTM(hidden_dim=256, lstm_layers=1, dropout=0.1).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.AdamW(vitlstm.parameters(), lr=1e-4, weight_decay=1e-4)

num_epochs = 10
patience = 5
best_f1 = 0.0
epochs_no_improve = 0

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
# scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)



In [15]:
train_model(
    vitlstm,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    best_model_path,
    scheduler,
    use_early_stopping=True,
    is_vit=True
)

[32m2025-08-11 15:43:12.710[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m76[0m - [1mEpoch 1/10[0m
[32m2025-08-11 15:43:12.711[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m77[0m - [1m  Train Loss: 1.3818 | Acc: 0.5991 | F1: 0.0924[0m
[32m2025-08-11 15:43:12.711[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m78[0m - [1m  Val   Loss: 1.4288 | Acc: 0.0612 | F1: 0.1154[0m
[32m2025-08-11 15:43:12.990[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mtrain_model[0m:[36m89[0m - [32m[1mNew best model saved at C:\Users\Molinari\Desktop\embryo-project\models\ViTLSTM.pth with F1: 0.1154[0m
[32m2025-08-11 15:50:53.012[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m76[0m - [1mEpoch 2/10[0m
[32m2025-08-11 15:50:53.012[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m77[0m - [1m  Train Loss: 1.3538 | Acc: 0.5481 | F1: 0.0988[0m
[32m2025-08-11 15:50:53.012[0m |

In [17]:
evaluate_and_save_results(
    vitlstm,
    test_loader,
    device,
    model_name
)

[32m2025-08-11 16:26:09.981[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mevaluate_and_save_results[0m:[36m40[0m - [32m[1mClassification report saved to C:\Users\Molinari\Desktop\embryo-project\reports\ViTLSTM_report.txt[0m
[32m2025-08-11 16:26:10.044[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mevaluate_and_save_results[0m:[36m52[0m - [32m[1mConfusion matrix saved to C:\Users\Molinari\Desktop\embryo-project\reports\figures\ViTLSTM_cm.png[0m


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


## 3DCNN

In [18]:
model_name = "3DCNN"
os.makedirs(MODELS_DIR, exist_ok=True)
best_model_path = MODELS_DIR / f"{model_name}.pth"

cnn = Simple3DCNN().to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.AdamW(cnn.parameters(), lr=1e-4, weight_decay=1e-4)

num_epochs = 10
patience = 5
best_f1 = 0.0
epochs_no_improve = 0

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
# scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

In [19]:
train_model(
    cnn,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    num_epochs,
    patience,
    best_model_path,
    scheduler,
    use_early_stopping=True,
    is_cnn=True
)

[32m2025-08-11 16:27:09.560[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m76[0m - [1mEpoch 1/10[0m
[32m2025-08-11 16:27:09.561[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m77[0m - [1m  Train Loss: 1.3123 | Acc: 0.1283 | F1: 0.0912[0m
[32m2025-08-11 16:27:09.561[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m78[0m - [1m  Val   Loss: 1.3315 | Acc: 0.3197 | F1: 0.1228[0m
[32m2025-08-11 16:27:09.564[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mtrain_model[0m:[36m89[0m - [32m[1mNew best model saved at C:\Users\Molinari\Desktop\embryo-project\models\3DCNN.pth with F1: 0.1228[0m
[32m2025-08-11 16:27:55.751[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m76[0m - [1mEpoch 2/10[0m
[32m2025-08-11 16:27:55.752[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m77[0m - [1m  Train Loss: 1.3044 | Acc: 0.6516 | F1: 0.1115[0m
[32m2025-08-11 16:27:55.752[0m | 

In [20]:
evaluate_and_save_results(
    cnn,
    test_loader,
    device,
    model_name,
    is_cnn=True
)

[32m2025-08-11 16:34:05.502[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mevaluate_and_save_results[0m:[36m40[0m - [32m[1mClassification report saved to C:\Users\Molinari\Desktop\embryo-project\reports\3DCNN_report.txt[0m
[32m2025-08-11 16:34:05.545[0m | [32m[1mSUCCESS [0m | [36m__main__[0m:[36mevaluate_and_save_results[0m:[36m52[0m - [32m[1mConfusion matrix saved to C:\Users\Molinari\Desktop\embryo-project\reports\figures\3DCNN_cm.png[0m
