In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from ca_utils import ResNet, BasicBlock
from torch.optim.lr_scheduler import OneCycleLR

In [46]:
# Preprocessing applied identically to the train and val splits:
# resize every image to a fixed 224x224, convert it to a float tensor,
# then normalise each RGB channel with the listed mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [47]:
train_dataset = datasets.ImageFolder(root='Desktop/train', transform=transform)
val_dataset = datasets.ImageFolder(root='Desktop/val', transform=transform)

# ImageFolder assigns integer labels from each root's own sorted sub-folder
# names, so a class folder missing from (or extra in) one split would silently
# shift every later label. Fail loudly instead of training/evaluating on
# mismatched labels.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    "train/val class folders differ; labels would not line up"
)

In [48]:
# Both loaders use the same batching/worker settings (previously the validation
# loader ran single-process without pinned memory). Page-locked host memory only
# speeds up host-to-GPU copies, so pin it only when a CUDA device exists.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=5,
    pin_memory=torch.cuda.is_available(),
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=5,
    pin_memory=torch.cuda.is_available(),
)

In [49]:
# Size the classifier head from the training folders instead of a hard-coded 10,
# so the output dimension always matches the labels ImageFolder produces.
# [1, 1, 1] is forwarded to ca_utils.ResNet unchanged — presumably one
# BasicBlock per stage; confirm against ca_utils.
model = ResNet(BasicBlock, [1, 1, 1], num_classes=len(train_dataset.classes))

In [50]:
def train_cnn(model, train_loader, epochs=50,
              save_path='Desktop/weights_resnet (1).pth', log_interval=1):
    """Train ``model`` with SGD and a one-cycle learning-rate schedule.

    The model is moved to CUDA when available (CPU otherwise) and trained in
    place with cross-entropy loss. After every epoch the current
    ``state_dict`` is written to ``save_path``, overwriting the previous
    checkpoint.

    Args:
        model: ``nn.Module`` whose output is passed to ``nn.CrossEntropyLoss``
            (presumably raw class logits — confirm against the model).
        train_loader: ``DataLoader`` yielding ``(data, target)`` batches; it
            must expose ``.dataset`` and support ``len()``.
        epochs: number of full passes over ``train_loader``.
        save_path: file the weights are saved to after each epoch.
        log_interval: print a loss line every ``log_interval`` batches;
            ``0`` disables per-batch logging.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.001)

    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)
    # OneCycleLR's schedule is sized in *batches* (epochs * steps_per_epoch),
    # so it must be stepped once per optimizer step. Stepping it once per epoch
    # walked through only 1/num_batches of the schedule: the LR stayed near its
    # warm-up start (max_lr / div_factor) and never reached max_lr.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.1, steps_per_epoch=num_batches, epochs=epochs
    )

    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            scheduler.step()  # per batch, after optimizer.step() — see note above

            if log_interval and batch_idx % log_interval == 0:
                print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{num_samples}"
                      f" ({100. * batch_idx / num_batches:.0f}%)]\tLoss: {loss.item():.6f}")

        # Same path every epoch, so the file always holds the latest weights.
        torch.save(model.state_dict(), save_path)

    print("Finished training and saved model weights")



In [51]:
# Kick off training on the training split; train_cnn writes the weights to disk
# after every epoch.
train_cnn(model, train_loader)

Finished training and saved model weights
