In [None]:
# 📌 train_model.ipynb

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time

# ✅ Setup
# Run configuration. data_dir is relative to the notebook's working
# directory — adjust if needed.
data_dir = '../data/raw'
batch_size = 16
num_epochs = 5
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 🔁 Transforms
# Preprocessing applied to every image:
#   1. resize to a fixed 224x224,
#   2. convert to a tensor,
#   3. normalize each channel with the given mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# 📂 Load Dataset
# Build an ImageFolder dataset from data_dir (one sub-folder per class) and a
# shuffling loader over it. class_names is used to size the classifier head
# and is written out next to the saved model.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
class_names = dataset.classes

# 🧠 Model Setup
# Start from an ImageNet-pretrained ResNet-18 and replace its final fully
# connected layer with one that outputs len(class_names) logits.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum and warns on the old flag. IMAGENET1K_V1 is the checkpoint that
# `pretrained=True` loads, so both branches start from identical weights.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    model = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:  # older torchvision without the multi-weight API
    model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))
model = model.to(device)

# ⚙️ Loss & Optimizer
# Cross-entropy for multi-class classification. Adam updates every model
# parameter, so the pretrained backbone is fine-tuned rather than frozen.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# 🏋️‍♂️ Training Loop
# Each epoch makes one pass over `dataloader`, updating the weights after every
# batch, then prints the mean per-sample training loss and training accuracy.
# NOTE(review): there is no held-out validation split, so the accuracy shown
# is measured on the same images the model is trained on.
print("Training...")
# Put layers with train/eval-specific behaviour into training mode explicitly,
# so re-running this cell after an evaluation cell still trains correctly.
model.train()
start = time.perf_counter()  # monotonic clock, suited to measuring durations
for epoch in range(num_epochs):
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_samples = labels.size(0)
        # CrossEntropyLoss's default reduction is the batch mean; weight it by
        # the batch size so a short final batch is not over-counted and the
        # epoch figure is a true per-sample average.
        running_loss += loss.item() * n_samples
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += n_samples

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

end = time.perf_counter()
print(f"\n✅ Finished Training in {(end - start)/60:.2f} minutes.")

# 💾 Save Model
# Write the trained weights (state_dict only) and the class-name list into a
# sibling "model" directory, relative to the notebook's working directory.
# Each path is built once, so the printed messages always match where the
# files were actually written.
model_dir = os.path.join("..", "model")
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, "model.pth")
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

# ✅ Save class names
# The list order matches the label indices the classifier head was trained on.
import json

class_names_path = os.path.join(model_dir, "class_names.json")
with open(class_names_path, "w", encoding="utf-8") as f:
    json.dump(class_names, f)
print("Class labels saved.")




Training...
