In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Step 1: Define transforms (resize + normalize)
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


train_data = datasets.ImageFolder(root="dataset/train", transform=transform)
val_data = datasets.ImageFolder(root="dataset/val", transform=transform)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16, shuffle=False)

# Step 3: Define a simple CNN (or use pretrained ResNet18)
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)  # chart vs nonchart

# Step 4: Loss & optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Step 5: Training loop
for epoch in range(5):
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/5], Loss: {loss.item():.4f}")

# Save model
torch.save(model.state_dict(), "chart_detector.pth")




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\MAINAK/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth


100.0%


Epoch [1/5], Loss: 0.0059
Epoch [2/5], Loss: 0.0004
Epoch [3/5], Loss: 0.0104
Epoch [4/5], Loss: 0.0000
Epoch [5/5], Loss: 0.0059
