In [1]:
# Standard library
import os
import sys
import time
from pathlib import Path

# Third-party
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tifffile
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# The project checkout must be on sys.path before its modules can be imported.
PROJECT_ROOT = Path("/mnt/home/dchhantyal/3d-cnn-classification")
sys.path.append(str(PROJECT_ROOT))

# Project-local model code. Note that DataLoader is taken from model.model here,
# not imported directly from torch.utils.data.
from model.model import ConvRNN, Config, resize_volume, NucleusDataset, DataLoader, RandomAugmentation3D

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


In [2]:
# Dataset location, derived from PROJECT_ROOT so the project root is spelled out
# in exactly one place. Kept as a plain string, matching the value it had before.
DATA_ROOT_DIR = str(PROJECT_ROOT / "data" / "nuclei_state_dataset")

# Run settings. Later cells read num_classes, device, learning_rate, batch_size
# and num_epochs from this object.
config = Config()

In [3]:
# 1. Initialize Model, Loss, and Optimizer
# Cross-entropy loss applied to the model outputs.
criterion = nn.CrossEntropyLoss()

# Build the classifier, then move it to the configured device.
model = ConvRNN(num_classes=config.num_classes)
model = model.to(config.device)

# Adam over all model parameters at the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

In [4]:
# 2. Prepare DataLoaders
# Dataset over DATA_ROOT_DIR, built without an explicit transform argument.
# A later cell reads only its length, to report the total sample count.
full_dataset = NucleusDataset(root_dir=DATA_ROOT_DIR)


In [5]:
 # 1. Create a dataset for training WITH the augmentation transform
train_full_dataset = NucleusDataset(
    root_dir=DATA_ROOT_DIR, transform=RandomAugmentation3D()
)

# 2. Create a second dataset for validation WITHOUT the transform
val_full_dataset = NucleusDataset(root_dir=DATA_ROOT_DIR, transform=None)

# The split is computed on one instance and its indices are applied to both, so
# the two instances must enumerate the same samples in the same order. If they
# did not, training samples could silently end up in the validation subset.
labels = [sample[1] for sample in train_full_dataset.samples]
val_labels = [sample[1] for sample in val_full_dataset.samples]
assert len(train_full_dataset) == len(val_full_dataset), (
    "train/val dataset instances differ in size; split indices would not line up"
)
assert labels == val_labels, (
    "train/val dataset instances list samples in a different order; "
    "split indices would not line up"
)

# 3. Perform the stratified split on indices (fixed seed keeps the split repeatable)
indices = list(range(len(train_full_dataset)))
train_indices, val_indices = train_test_split(
    indices, test_size=0.2, random_state=42, stratify=labels
)

# 4. Create Subsets using the correct dataset instance for each:
#    augmented samples for training, untransformed samples for validation.
train_dataset = Subset(train_full_dataset, train_indices)
val_dataset = Subset(val_full_dataset, val_indices)

# Only the training loader shuffles; validation order is kept fixed.
train_loader = DataLoader(
    train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2
)
val_loader = DataLoader(
    val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2
)

print(f"Found {len(full_dataset)} total samples.")
print(
    f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples."
)
print(
    f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters."
)

Found 605 total samples.
Training on 484 samples, validating on 121 samples.
Model has 97,731 trainable parameters.




In [None]:
# 3. Training
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: Network to run; switched to train mode when ``optimizer`` is
            given and to eval mode otherwise.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied as ``criterion(outputs, targets)``.
        device: Device the batches are moved to before the forward pass.
        optimizer: If provided, gradients are computed and a step is taken
            after every batch. If ``None``, the pass is evaluation-only and
            gradient tracking is disabled.

    Returns:
        Tuple of the per-sample mean loss and the accuracy in percent.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.set_grad_enabled(is_training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if is_training:
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            # Assumes the criterion returns the batch mean (the default for
            # nn.CrossEntropyLoss). Scaling by batch size turns it back into a
            # batch sum, so a short final batch is not over-weighted in the
            # epoch average.
            loss_sum += loss.item() * batch_size
            num_correct += (outputs.argmax(dim=1) == targets).sum().item()
            num_seen += batch_size

    return loss_sum / num_seen, 100 * num_correct / num_seen


for epoch in range(config.num_epochs):
    start_time = time.time()

    avg_train_loss, train_accuracy = _run_epoch(
        model, train_loader, criterion, config.device, optimizer
    )

    # 4. Validation (no optimizer: eval mode, gradients disabled)
    avg_val_loss, val_accuracy = _run_epoch(
        model, val_loader, criterion, config.device
    )

    epoch_duration = time.time() - start_time

    print(
        f"Epoch [{epoch+1}/{config.num_epochs}] | "
        f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}% | "
        f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}% | "
        f"Duration: {epoch_duration:.2f}s"
    )


Epoch [1/300] | Train Loss: 1.1040, Train Acc: 32.02% | Val Loss: 1.1059, Val Acc: 32.23% | Duration: 169.21s
Epoch [2/300] | Train Loss: 1.0997, Train Acc: 32.02% | Val Loss: 1.0978, Val Acc: 33.88% | Duration: 168.37s
Epoch [3/300] | Train Loss: 1.0963, Train Acc: 32.64% | Val Loss: 1.0924, Val Acc: 33.06% | Duration: 165.94s
Epoch [4/300] | Train Loss: 1.0935, Train Acc: 34.30% | Val Loss: 1.0862, Val Acc: 33.88% | Duration: 166.00s
Epoch [5/300] | Train Loss: 1.0935, Train Acc: 35.33% | Val Loss: 1.0840, Val Acc: 36.36% | Duration: 165.76s
Epoch [6/300] | Train Loss: 1.0888, Train Acc: 37.60% | Val Loss: 1.0766, Val Acc: 39.67% | Duration: 165.58s
Epoch [7/300] | Train Loss: 1.0865, Train Acc: 40.08% | Val Loss: 1.0741, Val Acc: 40.50% | Duration: 166.19s
Epoch [8/300] | Train Loss: 1.0881, Train Acc: 38.84% | Val Loss: 1.0734, Val Acc: 41.32% | Duration: 166.10s
Epoch [9/300] | Train Loss: 1.0871, Train Acc: 38.22% | Val Loss: 1.0672, Val Acc: 43.80% | Duration: 167.50s
Epoch [10/

In [None]:

# 6. Final model save
# Only the state dict is written to disk, not the module object itself.
final_model_path = "raw_masked_final_model.pth"
final_state = model.state_dict()
torch.save(final_state, final_model_path)
print(f"Final model saved to {final_model_path}")

print("\nTraining finished.")