In [None]:
# %%
# Install PyTorch Lightning and timm
!pip install pytorch-lightning timm

print("Dependencies installed. Please re-run the previous code block (Code Block 1) now.")

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.6-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading pytorch_lightning-2.5.6-py3-none-any.whl (831 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m831.6/831.6 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.15.2 pytorch-lightning-2.5.6 torchmetrics-1.8.2
Dependencies installed.

In [None]:
# %%
import torch
import torch.nn as nn
import pytorch_lightning as pl
import timm
import json
import os
from pytorch_lightning.callbacks import ModelCheckpoint
from google.colab import drive
# from project_data_module import YourDataModule # Keep this commented for now

# --- 1. Mount Google Drive ---
PROJECT_ROOT = '/content/drive/MyDrive/Hybrid_Project'
print(f"Project root set to: {PROJECT_ROOT}")

# Mount Drive
drive.mount('/content/drive')

# Create necessary directories
os.makedirs(os.path.join(PROJECT_ROOT, 'models'), exist_ok=True)
os.makedirs(os.path.join(PROJECT_ROOT, 'results'), exist_ok=True)

print("Google Drive mounted and directories checked.")

Project root set to: /content/drive/MyDrive/Hybrid_Project
Mounted at /content/drive
Google Drive mounted and directories checked.


In [None]:
# Add this line at the absolute top of the cell (no preceding spaces or lines)
%%writefile /content/drive/MyDrive/Hybrid_Project/project_data_module.py

import os
import torch
import logging
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from torchvision.datasets.folder import default_loader
from PIL import UnidentifiedImageError
# Note: PIL/Image is implicitly used by default_loader

# Configure logging to see which files are skipped
logging.basicConfig(level=logging.WARNING, format='%(levelname)s: %(message)s')

# --- 1. Robust Dataset Wrapper ---
class RobustImageFolder(Dataset):
    """A wrapper for ImageFolder that skips corrupted images."""
    def __init__(self, dataset):
        self.dataset = dataset
        self.loader = default_loader
        self.transform = dataset.transform

    def __getitem__(self, index):
        path, target = self.dataset.samples[index]

        try:
            # Load the image using the default loader
            sample = self.loader(path)

            if self.transform is not None:
                sample = self.transform(sample)

            return sample, target

        except (UnidentifiedImageError, OSError) as e:
            # If the image is corrupted or cannot be read, log and get a random new index
            logging.warning(f"Skipping corrupted file: {path}")

            # Get a random new index to fetch a valid sample instead
            new_index = torch.randint(0, len(self), (1,)).item()
            return self.__getitem__(new_index) # Recursively call to get a valid sample

    def __len__(self):
        return len(self.dataset)


# --- 2. Transformations ---
IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# --- 3. HybridDataModule ---
class HybridDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 32):
        super().__init__()
        # Use the passed data_dir argument
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = IMAGE_TRANSFORM

        self.train_dir = os.path.join(self.data_dir, 'train')
        self.val_dir = os.path.join(self.data_dir, 'valid')
        self.test_dir = os.path.join(self.data_dir, 'test')


    def setup(self, stage=None):
        try:
            print("Loading training dataset (Robust)...")
            # Load base ImageFolder, then wrap it with the RobustImageFolder
            base_train_dataset = datasets.ImageFolder(root=self.train_dir, transform=self.transform)
            self.train_dataset = RobustImageFolder(base_train_dataset)

            print("Loading validation dataset (Robust)...")
            base_val_dataset = datasets.ImageFolder(root=self.val_dir, transform=self.transform)
            self.val_dataset = RobustImageFolder(base_val_dataset)

            base_test_dataset = datasets.ImageFolder(root=self.test_dir, transform=self.transform)
            self.test_dataset = RobustImageFolder(base_test_dataset)

            self.num_classes = len(base_train_dataset.classes)
            print(f"Dataset loaded successfully. Found {self.num_classes} classes.")

        except Exception as e:
            print(f"⚠️ Warning: Failed to load REAL data (Exception during setup): {e}. Creating DUMMY data.")
            self.num_classes = 3
            # DUMMY DATA CREATION (This will only run if an exception occurs)
            self.train_dataset = [(torch.randn(3, 224, 224), torch.randint(0, self.num_classes, (1,)).item()) for _ in range(100)]
            self.val_dataset = [(torch.randn(3, 224, 224), torch.randint(0, self.num_classes, (1,)).item()) for _ in range(20)]
            self.test_dataset = [(torch.randn(3, 224, 224), torch.randint(0, self.num_classes, (1,)).item()) for _ in range(20)]


    def train_dataloader(self):
        # FIX: Setting num_workers=0 to prevent multiprocessing crash
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        # FIX: Setting num_workers=0
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=0)

    def test_dataloader(self):
        # FIX: Setting num_workers=0
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=0)



Overwriting /content/drive/MyDrive/Hybrid_Project/project_data_module.py


In [None]:
# %%
# Fix for ModuleNotFoundError when importing local files from Drive
import sys

# PROJECT_ROOT was defined in Code Block 1
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
    print(f"Added {PROJECT_ROOT} to sys.path for local module imports.")
else:
    print(f"{PROJECT_ROOT} already in sys.path.")

/content/drive/MyDrive/Hybrid_Project already in sys.path.


In [None]:
# %%
# Force a reload of the project_data_module file to ensure the RobustImageFolder fix is active
import importlib
try:
    import project_data_module
    importlib.reload(project_data_module)
    print("Forced reload of project_data_module.")
except ImportError:
    print("project_data_module not found, continuing...")

Forced reload of project_data_module.


In [None]:
# %%
# -----------------------------------------------------------
# 2. Load DataModule
# -----------------------------------------------------------

import torchmetrics

# Assuming 'project_data_module.py' is in the PROJECT_ROOT folder
from project_data_module import HybridDataModule
import torch
import torch.nn as nn
import pytorch_lightning as pl
import timm

# Since 'train', 'test', and 'valid' are directly in 'data/'
DATA_PATH = os.path.join(PROJECT_ROOT, 'data')

# Initialize the DataModule instance
dm = HybridDataModule(
    data_dir=DATA_PATH,
    batch_size=32,
    # image_size=224 is handled by your internal transforms
)

# Run setup to prepare data splits (this will determine NUM_CLASSES)
dm.setup('fit')

# Get the confirmed number of classes
NUM_CLASSES = dm.num_classes

print(f"Data Path defined: {DATA_PATH}")
print(f"Number of Classes obtained from DataModule: {NUM_CLASSES}")


# %%
# -----------------------------------------------------------
# 3. Define ViT Model (Frozen Backbone)
# -----------------------------------------------------------

class ViTModel(pl.LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes

        # Load ViT-Small (ViT-Small/16) from timm
        self.backbone = timm.create_model(
            'vit_small_patch16_224',
            pretrained=True,
            num_classes=0 # Set to 0 to remove the default classification head
        )

        # --- Freeze Backbone ---
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Feature dimension for vit_small_patch16_224 is 768
        feature_dim = self.backbone.num_features

        # --- Add Classifier Head ---
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        self.criterion = nn.CrossEntropyLoss()

        # Metrics for testing (Task 3)
        self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        # We use average='none' to get the class-wise results needed for the log
        self.test_f1 = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='none')

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

    # --- Training Step ---
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    # --- Test Step (Required for Task 3) ---
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)

        self.test_accuracy.update(preds, y)
        self.test_f1.update(preds, y)
        self.log('test_loss', loss, on_step=False, on_epoch=True)

    def on_test_epoch_end(self):
        # Calculate and log final metrics at the end of the test epoch
        self.log('accuracy', self.test_accuracy.compute())

        # Calculate class-wise results
        class_f1 = self.test_f1.compute()
        for i in range(self.num_classes):
             self.log(f'class_{i}_f1', class_f1[i], on_epoch=True)

        # Reset metrics for next use
        self.test_accuracy.reset()
        self.test_f1.reset()

    def configure_optimizers(self):
        # Only optimizing the un-frozen parameters (the classifier head)
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=1e-3)
        return optimizer

# Initialize the model
vit_model = ViTModel(num_classes=NUM_CLASSES)

print("-" * 30)
print(f"ViT Model initialized with {NUM_CLASSES} classes.")
print(f"Backbone requires_grad status (should be False): {vit_model.backbone.parameters().__next__().requires_grad}")
print("Classifier requires_grad status (should be True):", vit_model.classifier.parameters().__next__().requires_grad)
print("-" * 30)

Loading training dataset (Robust)...
Loading validation dataset (Robust)...
Dataset loaded successfully. Found 3 classes.
Data Path defined: /content/drive/MyDrive/Hybrid_Project/data
Number of Classes obtained from DataModule: 3
------------------------------
ViT Model initialized with 3 classes.
Backbone requires_grad status (should be False): False
Classifier requires_grad status (should be True): True
------------------------------


In [25]:
# %%
# -----------------------------------------------------------
# 4. Initialize Trainer & Run Training (5–7 epochs)
# -----------------------------------------------------------

from pytorch_lightning.callbacks import ModelCheckpoint
import torch

# Define the path for saving model weights
MODEL_SAVE_DIR = os.path.join(PROJECT_ROOT, 'models')
MODEL_FILENAME = 'vit_day3' # Base name for the file

# Define a checkpoint callback to save the model weights (Task 2)
# The actual saved file will be vit_day3-vX.ckpt (PyTorch Lightning default)
checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_SAVE_DIR,
    filename=MODEL_FILENAME,
    monitor='train_loss',
    mode='min',
    save_last=True,
    save_top_k=1, # Save only the best model
    verbose=True
)

# Initialize Trainer
trainer = pl.Trainer(
    max_epochs=7,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback],
    log_every_n_steps=50,
    # Use 16-mixed precision for faster training on GPU
    precision='16-mixed' if torch.cuda.is_available() else '32',
)

print(f"Trainer initialized for max {trainer.max_epochs} epochs.")
print("Starting ViT training (Task 1)...")

# --- Run TRAINING for 5–7 epochs (main part of Task 1) ---
trainer.fit(vit_model, datamodule=dm)

print("Training finished.")

# Task 2: Confirmation of saved weights
print(f"Model weights saved to Drive. Best checkpoint path: {checkpoint_callback.best_model_path}")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores


Trainer initialized for max 7 epochs.
Starting ViT training (Task 1)...
Loading training dataset (Robust)...


INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type               | Params | Mode 
-------------------------------------------------------------
0 | backbone      | VisionTransformer  | 21.7 M | train
1 | classifier    | Sequential         | 198 K  | train
2 | criterion     | CrossEntropyLoss   | 0      | train
3 | test_accuracy | MulticlassAccuracy | 0      | train
4 | test_f1       | MulticlassF1Score  | 0      | train
-------------------------------------------------------------
198 K     Trainable params
21.7 M    Non-trainable params
21.9 M    Total params
87.457    Total estimated model params size (MB)
284       Modules in train mode
0         Modules in eval mode


Loading validation dataset (Robust)...
Dataset loaded successfully. Found 3 classes.


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 529: 'train_loss' reached 0.75283 (best 0.75283), saving model to '/content/drive/MyDrive/Hybrid_Project/models/vit_day3-v1.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 1058: 'train_loss' reached 0.66312 (best 0.66312), saving model to '/content/drive/MyDrive/Hybrid_Project/models/vit_day3-v1.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 1587: 'train_loss' reached 0.63756 (best 0.63756), saving model to '/content/drive/MyDrive/Hybrid_Project/models/vit_day3-v1.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:Epoch 3, global step 2116: 'train_loss' reached 0.62569 (best 0.62569), saving model to '/content/drive/MyDrive/Hybrid_Project/models/vit_day3-v1.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:Epoch 4, global step 2645: 'train_loss' reached 0.61514 (best 0.61514), saving model to '/content/drive/MyDrive/Hybrid_Project/models/vit_day3

Training finished.
Model weights saved to Drive. Best checkpoint path: /content/drive/MyDrive/Hybrid_Project/models/vit_day3-v1.ckpt


In [26]:
# %%
# -----------------------------------------------------------
# 5. Run Test Phase (Task 3)
# -----------------------------------------------------------
import json
import os

print("Starting ViT Test Phase...")

# The best model path determined by the checkpoint callback is used.
BEST_MODEL_PATH = '/content/drive/MyDrive/Hybrid_Project/models/vit_day3-v1.ckpt' # Explicitly set based on output

# Load the best model from the checkpoint
# The ViTModel class definition must be available (which it is)
vit_model_test = ViTModel.load_from_checkpoint(
    BEST_MODEL_PATH,
    num_classes=NUM_CLASSES # NUM_CLASSES should still be 3 from Code Block 2
)

# Run the test loop using the same DataModule (dm)
test_results = trainer.test(vit_model_test, datamodule=dm)

print("Test phase complete.")
print(f"Test Results: {test_results}")

# %%
# -----------------------------------------------------------
# 6. Save Test Results (Task 3)
# -----------------------------------------------------------

RESULTS_FILE_PATH = os.path.join(PROJECT_ROOT, 'results', 'test_results_vit_day3.json')

# The test_results is a list of dicts; we save the first element (the main results)
# The dictionary contains: test_loss, accuracy, and class_0_f1, class_1_f1, etc.
with open(RESULTS_FILE_PATH, 'w') as f:
    # We save the results from the first element of the list
    json.dump(test_results[0], f, indent=4)

print(f"Test results saved to: {RESULTS_FILE_PATH}")

Starting ViT Test Phase...
Loading training dataset (Robust)...
Loading validation dataset (Robust)...
Dataset loaded successfully. Found 3 classes.


Testing: |          | 0/? [00:00<?, ?it/s]

Test phase complete.
Test Results: [{'test_loss': 0.5377506017684937, 'accuracy': 0.8147208094596863, 'class_0_f1': 0.8989361524581909, 'class_1_f1': 0.0, 'class_2_f1': 0.7692307829856873}]
Test results saved to: /content/drive/MyDrive/Hybrid_Project/results/test_results_vit_day3.json
