In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torch import nn
import os

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")


In [None]:
# Echo the selected device so the saved notebook records which backend this run used.
device

device(type='cuda', index=0)

In [None]:
# Step 1: Load MobileNetV2 with ImageNet-pretrained weights and place it on `device`.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in favour of
# the `weights=` argument; it is kept here for compatibility with the installed version — TODO confirm.
model = torchvision.models.mobilenet_v2(pretrained=True).to(device)

# Step 2: Freeze every existing parameter so that only layers attached afterwards
# (the replacement classifier head) can receive gradients.
model = model.requires_grad_(False)

In [None]:
# Step 3: Replace the ImageNet head with a freshly initialised linear layer for our classes.
# A new nn.Linear has trainable parameters, so after the backbone freeze above it is the
# only part of the network whose weights get updated.
NUM_CLASSES = 4  # must equal the number of class sub-folders under train/ and valid/
# `last_channel` is the width of MobileNetV2's final feature vector (1280 with the default
# width multiplier); reading it avoids a hard-coded magic number if the backbone changes.
model.classifier = torch.nn.Linear(in_features=model.last_channel, out_features=NUM_CLASSES).to(device)

In [None]:

# Step 4: Build the image transforms and the train/validation dataloaders.
# The data is expected in ImageFolder layout: one sub-folder per class under `train/`
# and `valid/` (4 classes, matching the classifier head's output size).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = 128   # target size of the shorter image side for Resize
BATCH_SIZE = 512

data_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    # Light augmentation, applied to the already-normalised tensor.
    transforms.RandomVerticalFlip(0.1),
    # NOTE(review): RandomRotation takes degrees, so (-0.15, 0.15) rotates by at most 0.15°,
    # which is effectively a no-op. If ±0.15 rad (~8.6°) was intended, use (-8.6, 8.6). TODO confirm.
    transforms.RandomRotation((-0.15, 0.15)),
])

# Validation uses only the deterministic part of the pipeline (no augmentation).
val_data_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Load the training data
train_dataset = torchvision.datasets.ImageFolder(root='train', transform=data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

# Load the validation data
val_dataset = torchvision.datasets.ImageFolder(root='valid', transform=val_data_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

# Each ImageFolder derives its label indices from its own root's sub-folder names, so a class
# folder missing from one split would silently shift labels between train and validation.
assert val_dataset.class_to_idx == train_dataset.class_to_idx, (
    f"train/valid class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)
# The classifier head must produce one logit per class found on disk.
assert len(train_dataset.classes) == model.classifier.out_features, (
    f"found {len(train_dataset.classes)} classes but the classifier outputs "
    f"{model.classifier.out_features}"
)

In [None]:
# Step 5: Make sure the model sits on the training device, then set up the optimiser,
# the loss and the run-level settings read by the training loop below.
model = model.to(device)

# Adam with its default hyper-parameters. The frozen backbone parameters are passed in too,
# but they never receive gradients, so only the new classifier head is actually updated.
optimizer = torch.optim.Adam(params=model.parameters())
loss_fn = nn.CrossEntropyLoss()

num_epochs = 300          # number of passes over the training set
best_loss = float("inf")  # lowest validation loss seen so far; lowered by the training loop

In [None]:
# Train the classifier head and evaluate on the validation split after every epoch.
# Checkpoints are written to CHECKPOINT_DIR: one file per epoch, plus `best_model.pt`
# holding the weights with the lowest validation loss seen so far (tracked in `best_loss`).
CHECKPOINT_DIR = 'best_models'
# torch.save does not create missing directories, so make sure the target exists up front.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def _train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean_loss, accuracy), both averaged per sample over ``len(loader.dataset)``.
        They are measured on the fly, i.e. while the weights are still changing.
    """
    # NOTE(review): train() also switches the frozen backbone's normalisation layers to
    # training mode, so their running statistics keep adapting to this dataset even though
    # their weights are frozen. If the pretrained statistics should be kept, put those
    # layers back into eval() mode here — TODO confirm which behaviour is intended.
    model.train()
    total_loss, total_correct = 0.0, 0
    for inputs, labels in tqdm(loader):
        # pin_memory=True on the loaders lets these host-to-device copies run asynchronously.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is a batch mean; weight it by batch size so the epoch average is per sample.
        total_loss += loss.item() * inputs.size(0)
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_correct / n_samples


@torch.no_grad()
def _evaluate(model, loader, loss_fn, device):
    """Return (mean_loss, accuracy) of `model` on `loader` without tracking gradients."""
    model.eval()
    total_loss, total_correct = 0.0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = model(inputs)
        total_loss += loss_fn(outputs, labels).item() * inputs.size(0)
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_correct / n_samples


for epoch in range(num_epochs):
    train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    val_loss, val_acc = _evaluate(model, val_loader, loss_fn, device)

    # Print training and validation results
    print(f'Epoch {epoch + 1}: train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, val_loss = {val_loss:.4f}, val_acc = {val_acc:.4f}')

    # Save model every epoch
    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f'model_{epoch + 1}.pt'))

    # Keep a separate copy of the weights with the lowest validation loss so far.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_model.pt'))


20it [01:55,  5.76s/it]


Epoch 1: train_loss = 1.1131, train_acc = 0.5256, val_loss = 0.9614, val_acc = 0.6068


20it [01:53,  5.68s/it]


Epoch 2: train_loss = 0.9198, train_acc = 0.6255, val_loss = 0.8483, val_acc = 0.6480


20it [01:54,  5.72s/it]


Epoch 3: train_loss = 0.8433, train_acc = 0.6450, val_loss = 0.7808, val_acc = 0.6837


20it [01:56,  5.83s/it]


Epoch 4: train_loss = 0.8071, train_acc = 0.6694, val_loss = 0.7526, val_acc = 0.6855


20it [01:57,  5.88s/it]


Epoch 5: train_loss = 0.7719, train_acc = 0.6795, val_loss = 0.7269, val_acc = 0.7065


20it [01:56,  5.82s/it]


Epoch 6: train_loss = 0.7475, train_acc = 0.6954, val_loss = 0.7137, val_acc = 0.7046


20it [01:56,  5.84s/it]


Epoch 7: train_loss = 0.7328, train_acc = 0.7054, val_loss = 0.7001, val_acc = 0.7157


20it [01:57,  5.86s/it]


Epoch 8: train_loss = 0.7096, train_acc = 0.7105, val_loss = 0.6908, val_acc = 0.7372


20it [01:53,  5.67s/it]


Epoch 9: train_loss = 0.7007, train_acc = 0.7215, val_loss = 0.6959, val_acc = 0.7348


20it [01:51,  5.58s/it]


Epoch 10: train_loss = 0.6977, train_acc = 0.7182, val_loss = 0.6817, val_acc = 0.7329


20it [01:51,  5.58s/it]


Epoch 11: train_loss = 0.6750, train_acc = 0.7335, val_loss = 0.6673, val_acc = 0.7434


20it [01:52,  5.62s/it]


Epoch 12: train_loss = 0.6662, train_acc = 0.7342, val_loss = 0.6776, val_acc = 0.7452


20it [01:49,  5.48s/it]


Epoch 13: train_loss = 0.6776, train_acc = 0.7261, val_loss = 0.6475, val_acc = 0.7495


20it [01:50,  5.54s/it]


Epoch 14: train_loss = 0.6486, train_acc = 0.7450, val_loss = 0.6410, val_acc = 0.7428


20it [01:50,  5.54s/it]


Epoch 15: train_loss = 0.6385, train_acc = 0.7489, val_loss = 0.6400, val_acc = 0.7551


20it [01:49,  5.49s/it]


Epoch 16: train_loss = 0.6391, train_acc = 0.7520, val_loss = 0.6361, val_acc = 0.7563


20it [01:51,  5.58s/it]


Epoch 17: train_loss = 0.6449, train_acc = 0.7412, val_loss = 0.6644, val_acc = 0.7446


20it [01:51,  5.57s/it]


Epoch 18: train_loss = 0.6321, train_acc = 0.7482, val_loss = 0.6200, val_acc = 0.7785


20it [01:54,  5.72s/it]


Epoch 19: train_loss = 0.6198, train_acc = 0.7596, val_loss = 0.6157, val_acc = 0.7698


20it [01:48,  5.43s/it]


Epoch 20: train_loss = 0.6180, train_acc = 0.7615, val_loss = 0.6189, val_acc = 0.7551


20it [01:53,  5.66s/it]


Epoch 21: train_loss = 0.6112, train_acc = 0.7631, val_loss = 0.6099, val_acc = 0.7705


20it [01:49,  5.45s/it]


Epoch 22: train_loss = 0.6067, train_acc = 0.7722, val_loss = 0.6092, val_acc = 0.7649


20it [01:48,  5.40s/it]


Epoch 23: train_loss = 0.6004, train_acc = 0.7701, val_loss = 0.6052, val_acc = 0.7778


20it [01:48,  5.44s/it]


Epoch 24: train_loss = 0.5966, train_acc = 0.7753, val_loss = 0.5978, val_acc = 0.7797


20it [01:46,  5.33s/it]


Epoch 25: train_loss = 0.5868, train_acc = 0.7770, val_loss = 0.6152, val_acc = 0.7686


20it [01:44,  5.24s/it]


Epoch 26: train_loss = 0.5925, train_acc = 0.7755, val_loss = 0.6049, val_acc = 0.7612


20it [01:46,  5.30s/it]


Epoch 27: train_loss = 0.5865, train_acc = 0.7754, val_loss = 0.6099, val_acc = 0.7723


20it [01:45,  5.29s/it]


Epoch 28: train_loss = 0.5826, train_acc = 0.7785, val_loss = 0.5949, val_acc = 0.7822


20it [01:47,  5.39s/it]


Epoch 29: train_loss = 0.5754, train_acc = 0.7857, val_loss = 0.5952, val_acc = 0.7766


20it [01:47,  5.39s/it]


Epoch 30: train_loss = 0.5805, train_acc = 0.7812, val_loss = 0.5924, val_acc = 0.7803


20it [01:49,  5.47s/it]


Epoch 31: train_loss = 0.5702, train_acc = 0.7849, val_loss = 0.5817, val_acc = 0.7840


20it [01:48,  5.43s/it]


Epoch 32: train_loss = 0.5712, train_acc = 0.7850, val_loss = 0.5923, val_acc = 0.7717


20it [01:50,  5.53s/it]


Epoch 33: train_loss = 0.5656, train_acc = 0.7812, val_loss = 0.5806, val_acc = 0.7815


20it [01:48,  5.43s/it]


Epoch 34: train_loss = 0.5648, train_acc = 0.7865, val_loss = 0.5796, val_acc = 0.7858


20it [01:47,  5.38s/it]


Epoch 35: train_loss = 0.5619, train_acc = 0.7913, val_loss = 0.5854, val_acc = 0.7692


20it [01:51,  5.59s/it]


Epoch 36: train_loss = 0.5617, train_acc = 0.7877, val_loss = 0.5885, val_acc = 0.7735


20it [01:56,  5.83s/it]


Epoch 37: train_loss = 0.5621, train_acc = 0.7820, val_loss = 0.5857, val_acc = 0.7846


20it [01:57,  5.88s/it]


Epoch 38: train_loss = 0.5490, train_acc = 0.7942, val_loss = 0.5682, val_acc = 0.7994


20it [01:58,  5.91s/it]


Epoch 39: train_loss = 0.5488, train_acc = 0.7979, val_loss = 0.5688, val_acc = 0.8012


20it [01:57,  5.88s/it]


Epoch 40: train_loss = 0.5458, train_acc = 0.7978, val_loss = 0.5641, val_acc = 0.7902


20it [01:50,  5.53s/it]


Epoch 41: train_loss = 0.5330, train_acc = 0.8012, val_loss = 0.5729, val_acc = 0.7957


20it [01:51,  5.56s/it]


Epoch 42: train_loss = 0.5504, train_acc = 0.7928, val_loss = 0.5715, val_acc = 0.7803


20it [01:48,  5.44s/it]


Epoch 43: train_loss = 0.5636, train_acc = 0.7826, val_loss = 0.5842, val_acc = 0.7809


20it [01:46,  5.33s/it]


Epoch 44: train_loss = 0.5404, train_acc = 0.7986, val_loss = 0.5672, val_acc = 0.7877


20it [01:47,  5.37s/it]


Epoch 45: train_loss = 0.5398, train_acc = 0.8004, val_loss = 0.5615, val_acc = 0.7957


20it [01:47,  5.38s/it]


Epoch 46: train_loss = 0.5388, train_acc = 0.7979, val_loss = 0.5802, val_acc = 0.7828


20it [01:49,  5.46s/it]


Epoch 47: train_loss = 0.5346, train_acc = 0.8000, val_loss = 0.5560, val_acc = 0.7920


20it [01:48,  5.42s/it]


Epoch 48: train_loss = 0.5313, train_acc = 0.8002, val_loss = 0.5812, val_acc = 0.7785


20it [01:47,  5.36s/it]


Epoch 49: train_loss = 0.5340, train_acc = 0.7986, val_loss = 0.5460, val_acc = 0.7988


20it [01:47,  5.39s/it]


Epoch 50: train_loss = 0.5218, train_acc = 0.8100, val_loss = 0.5435, val_acc = 0.7938


20it [01:47,  5.36s/it]


Epoch 51: train_loss = 0.5218, train_acc = 0.8101, val_loss = 0.5512, val_acc = 0.7908


20it [01:50,  5.52s/it]


Epoch 52: train_loss = 0.5274, train_acc = 0.8039, val_loss = 0.5612, val_acc = 0.7902


20it [01:48,  5.43s/it]


Epoch 53: train_loss = 0.5161, train_acc = 0.8129, val_loss = 0.5460, val_acc = 0.7994


20it [01:47,  5.36s/it]


Epoch 54: train_loss = 0.5150, train_acc = 0.8113, val_loss = 0.5439, val_acc = 0.8006


20it [01:49,  5.46s/it]


Epoch 55: train_loss = 0.5111, train_acc = 0.8167, val_loss = 0.5505, val_acc = 0.8006


15it [01:25,  5.73s/it]