<a href="https://colab.research.google.com/github/jmand626/EXP-tracker-pro/blob/main/ExperimentTracker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 1. Imports and Setup
import os
import sys
import torch
import torchvision
import shutil
from pathlib import Path

# Mount Drive
from google.colab import drive
drive.mount('/content/drive')

# Clone Repo
repo_url = "https://github.com/jmand626/PyTorchMLEngine-Custom-Dataset-Project.git"
repo_name = "PyTorchMLEngine-Custom-Dataset-Project"

if not os.path.exists(repo_name):
    print(f"Cloning {repo_url}...")
    !git clone {repo_url}

# Add to sys.path
os.chdir(repo_name)
sys.path.append(os.getcwd())
print(f"Current working directory: {os.getcwd()}")

# Install TensorBoard and Torchinfo
!pip install -q torchinfo tensorboard

# Define Paths
gdrive_train_dir = "/content/drive/MyDrive/data/fgvc-aircraft-2013b/train"
gdrive_test_dir = "/content/drive/MyDrive/data/fgvc-aircraft-2013b/test"
local_train_dir = "/content/train"
local_test_dir = "/content/test"

# COPY DATA TO LOCAL VM (Crucial for speed)
print("Copying dataset from Google Drive to local VM...")
if not os.path.exists(local_train_dir):
    shutil.copytree(gdrive_train_dir, local_train_dir)
    print("Train set copied.")
else:
    print("Train set already exists locally.")

if not os.path.exists(local_test_dir):
    shutil.copytree(gdrive_test_dir, local_test_dir)
    print("Test set copied.")
else:
    print("Test set already exists locally.")

# Setup Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Mounted at /content/drive
Cloning https://github.com/jmand626/PyTorchMLEngine-Custom-Dataset-Project.git...
Cloning into 'PyTorchMLEngine-Custom-Dataset-Project'...
remote: Enumerating objects: 24, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 24 (delta 6), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (24/24), 21.93 KiB | 5.48 MiB/s, done.
Resolving deltas: 100% (6/6), done.
Current working directory: /content/PyTorchMLEngine-Custom-Dataset-Project
Copying dataset from Google Drive to local VM...
Train set copied.
Test set copied.
Device: cpu


In [2]:
import torchvision.transforms as transforms
import setup_dataholders
import importlib
# Reload so edits made to setup_dataholders.py after the first import are picked up
importlib.reload(setup_dataholders)

# Preprocessing applied to every image (standard ImageNet normalization):
# resize to 224x224 -> convert to tensor -> per-channel normalize.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
manual_transforms = transforms.Compose(preprocessing_steps)

# Keyword arguments for the project's DataLoader factory.
# workers=0 is deliberate: it was chosen to avoid a deadlock seen earlier
# (presumably from multi-process data loading — confirm in setup_dataholders).
dataloader_kwargs = {
    "train_directory": local_train_dir,
    "test_directory": local_test_dir,
    "data_transforms": manual_transforms,
    "batch_size": 32,
    "workers": 0,
}

# The factory hands back the train loader, the test loader and the class names
train_dataloader, test_dataloader, class_names = setup_dataholders.create_dataloaders(
    **dataloader_kwargs
)

# Report how many classes were found and preview the first ten names
print(f"Number of classes: {len(class_names)}")
print(f"Classes: {class_names[:10]}...")

Number of classes: 100
Classes: ['707_320', '727_200', '737_200', '737_300', '737_400', '737_500', '737_600', '737_700', '737_800', '737_900']...


In [3]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

def create_writer(experiment_name: str,
                  model_name: str,
                  extra: str=None) -> SummaryWriter:
    """
    Builds a SummaryWriter whose log directory encodes the date and run identifiers.

    The directory is runs/<YYYY-MM-DD>/<experiment_name>/<model_name>[/<extra>],
    relative to the current working directory. The date has day granularity,
    so calls made on the same day with the same arguments share one directory.

    Args:
        experiment_name: Path component identifying the experiment.
        model_name: Path component identifying the model.
        extra: Optional final path component; skipped when None or empty.

    Returns:
        A SummaryWriter constructed with the computed log_dir.
    """
    # Day-level stamp used as the first sub-directory under "runs"
    today = datetime.now().strftime("%Y-%m-%d")

    # Collect the path components, adding the optional suffix only when given
    path_parts = ["runs", today, experiment_name, model_name]
    if extra:
        path_parts.append(extra)
    log_dir = os.path.join(*path_parts)

    print(f"[INFO] Created SummaryWriter, saving to: {log_dir}...")
    return SummaryWriter(log_dir=log_dir)

In [4]:
from typing import Dict, List
from tqdm.auto import tqdm
import torch.nn as nn
from model_backbone import train_step, run_test_step # Import steps from your file

def train_with_tracking(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device,
          writer: SummaryWriter = None) -> Dict[str, List]:
    """
    Trains and evaluates a model for a number of epochs, optionally logging to TensorBoard.

    Each epoch calls train_step and then run_test_step, prints the four metrics,
    logs them under the "Loss" and "Accuracy" tags when a writer is given, and
    appends them to the returned history.

    Args:
        model: Model passed through to train_step / run_test_step.
        train_dataloader: DataLoader handed to train_step.
        test_dataloader: DataLoader handed to run_test_step.
        optimizer: Optimizer handed to train_step.
        loss_fn: Loss function handed to both steps.
        epochs: Number of epochs to run.
        device: Device handed to both steps.
        writer: Optional SummaryWriter. When None, nothing is logged. When
            given, it is flushed after every epoch and closed once before this
            function returns or raises, so the caller must not reuse it.

    Returns:
        Dict with keys "train_loss", "train_acc", "test_loss" and "test_acc",
        each a list holding one value per completed epoch.
    """
    results = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

    try:
        for epoch in tqdm(range(epochs)):
            # 1. Train Step
            train_loss, train_acc = train_step(model=model,
                                               dataloader=train_dataloader,
                                               loss_fn=loss_fn,
                                               optimizer=optimizer,
                                               device=device)
            # 2. Test Step
            test_loss, test_acc = run_test_step(model=model,
                                                dataloader=test_dataloader,
                                                loss_fn=loss_fn,
                                                device=device)

            print(f"Epoch: {epoch+1} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | test_loss: {test_loss:.4f} | test_acc: {test_acc:.4f}")

            # 3. Log to TensorBoard
            if writer is not None:
                writer.add_scalars(main_tag="Loss",
                                   tag_scalar_dict={"train_loss": train_loss, "test_loss": test_loss},
                                   global_step=epoch)
                writer.add_scalars(main_tag="Accuracy",
                                   tag_scalar_dict={"train_acc": train_acc, "test_acc": test_acc},
                                   global_step=epoch)
                # Flush (not close) so this epoch's values reach disk while the
                # same event files stay open for the following epochs.
                writer.flush()

            results["train_loss"].append(train_loss)
            results["train_acc"].append(train_acc)
            results["test_loss"].append(test_loss)
            results["test_acc"].append(test_acc)
    finally:
        # Close the writer exactly once, after the last epoch or on failure.
        # Closing inside the loop made every later epoch reopen the writer.
        if writer is not None:
            writer.close()

    return results