In [1]:
import pandas as pd
import numpy as np
import torch
import os
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

In [2]:
class Lego_Dataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, path, labels, transform=None):
        """
        Args:
            file_paths (list): List of file paths for the images.
            labels (list): List of corresponding labels.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.path = path

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = Image.open(os.path.join(self.path,img_path)).convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label

In [3]:
#list of filenames
wd = os.getcwd()
path = os.path.join(wd,"C:/Users/dchro/Documents/MLOps/mlops_project/data/external/lego_dataset")
print(wd, path)
print(os.path.join(path, 'index.csv'))
index = pd.read_csv( 'C:/Users/dchro/Documents/MLOps/mlops_project/data/external/lego_dataset/index.csv')
labels = index["class_id"]-1
files = index["path"]

c:\Users\dchro\Documents\MLOps\mlops_project\notebooks C:/Users/dchro/Documents/MLOps/mlops_project/data/external/lego_dataset
C:/Users/dchro/Documents/MLOps/mlops_project/data/external/lego_dataset\index.csv


In [7]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = Lego_Dataset(file_paths=files, path = path, labels=labels,transform=transform)

trainset[0][0].shape

In [9]:
train_loader = DataLoader(trainset, batch_size=32, shuffle=True)

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

num_epochs = 5
lr = 0.003

num_classes = 38 


#model = timm.create_model('mobilenetv3_small_100', pretrained=True, num_classes=num_classes)



model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_classes)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


# Training loop
for ep in range(num_epochs):
    

    total_loss = 0
    num_correct = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        y_hat = model(inputs)
        batch_loss = criterion(y_hat, labels)
        batch_loss.backward()
        optimizer.step()

        total_loss += float(batch_loss)
        num_correct += int(torch.sum(torch.argmax(y_hat, dim=1) == labels))

        
        print(
            "EPOCH: {:5}/tBATCH: {:5}/{:5}/tLOSS: {:.3f}".format(
                ep, batch_idx, len(train_loader), batch_loss
            )
        )

    epoch_loss = total_loss / len(trainset)
    epoch_accuracy = num_correct / len(trainset)
    print(
        "EPOCH: {:5} /t LOSS: {:.3f} /t ACCURACY: {:.3f}".format(
            ep, epoch_loss, epoch_accuracy
        )
    )


    
    # Validation loop (optional)
    #model.eval()
    #with torch.no_grad():
    #    for inputs, labels in val_loader:
    #        outputs = model(inputs)
            # Calculate validation loss and metrics

# Save the trained model
#torch.save(model.state_dict(), 'mobilenetv3_fine_tuned.pth')




EPOCH:     0/tBATCH:     0/   12/tLOSS: 3.996
EPOCH:     0/tBATCH:     1/   12/tLOSS: 6.326
EPOCH:     0/tBATCH:     2/   12/tLOSS: 3.532
EPOCH:     0/tBATCH:     3/   12/tLOSS: 4.061
EPOCH:     0/tBATCH:     4/   12/tLOSS: 3.867
EPOCH:     0/tBATCH:     5/   12/tLOSS: 3.061
EPOCH:     0/tBATCH:     6/   12/tLOSS: 3.496
EPOCH:     0/tBATCH:     7/   12/tLOSS: 3.318
EPOCH:     0/tBATCH:     8/   12/tLOSS: 3.327
EPOCH:     0/tBATCH:     9/   12/tLOSS: 3.070
EPOCH:     0/tBATCH:    10/   12/tLOSS: 2.603
EPOCH:     0/tBATCH:    11/   12/tLOSS: 1.867
EPOCH:     0 /t LOSS: 0.115 /t ACCURACY: 0.181
EPOCH:     1/tBATCH:     0/   12/tLOSS: 1.402
EPOCH:     1/tBATCH:     1/   12/tLOSS: 1.776
EPOCH:     1/tBATCH:     2/   12/tLOSS: 1.841
EPOCH:     1/tBATCH:     3/   12/tLOSS: 1.062
EPOCH:     1/tBATCH:     4/   12/tLOSS: 1.517
EPOCH:     1/tBATCH:     5/   12/tLOSS: 1.146
EPOCH:     1/tBATCH:     6/   12/tLOSS: 1.101
EPOCH:     1/tBATCH:     7/   12/tLOSS: 1.617
EPOCH:     1/tBATCH:     8/   12/

KeyboardInterrupt: 