In [1]:
import pandas as pd
import numpy as np
import torch
import os
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

In [2]:
class Lego_Dataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, path, labels, transform=None):
        """
        Args:
            file_paths (list): List of file paths for the images.
            labels (list): List of corresponding labels.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.path = path

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = Image.open(os.path.join(self.path,img_path)).convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label

In [13]:
#list of filenames
parent = os.path.dirname(os.path.dirname(os.getcwd()))
print(parent)
path_index = Path('data/external/lego_dataset/index.csv')
print(os.path.join(parent,path_index))
#path = "C:/Users/Lennart/Documents/GitHub/mlops_project/data/external/lego_dataset"
index = pd.read_csv(os.path.join(parent,path_index))

index
#labels = index["class_id"]-1
#files = index["path"]

c:\Users\Lennart\Documents\GitHub\mlops_project
c:\Users\Lennart\Documents\GitHub\mlops_project\data\external\lego_dataset\index.csv


Unnamed: 0,path,class_id
0,marvel/0001/001.jpg,1
1,marvel/0001/002.jpg,1
2,marvel/0001/003.jpg,1
3,marvel/0001/004.jpg,1
4,marvel/0001/005.jpg,1
...,...,...
366,star-wars/0017/006.jpg,38
367,star-wars/0017/007.jpg,38
368,star-wars/0017/008.jpg,38
369,star-wars/0017/009.jpg,38


In [46]:
print("Hello World")

Hello World


In [30]:
labels

0       0
1       0
2       0
3       0
4       0
       ..
366    37
367    37
368    37
369    37
370    37
Name: class_id, Length: 371, dtype: int64

In [35]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

In [36]:
trainset = Lego_Dataset(file_paths=files, path = path, labels=labels,transform=transform)

In [37]:
trainset[0][0].shape

torch.Size([3, 512, 512])

In [38]:
train_loader = DataLoader(trainset, batch_size=32, shuffle=True)

In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

num_epochs = 5
lr = 0.003

num_classes = 38 


#model = timm.create_model('mobilenetv3_small_100', pretrained=True, num_classes=num_classes)



model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_classes)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


# Adjust this according to your needs


for ep in range(num_epochs):
    

    total_loss = 0
    num_correct = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        y_hat = model(inputs)
        batch_loss = criterion(y_hat, labels)
        batch_loss.backward()
        optimizer.step()

        total_loss += float(batch_loss)
        num_correct += int(torch.sum(torch.argmax(y_hat, dim=1) == labels))

        
        print(
            "EPOCH: {:5}/tBATCH: {:5}/{:5}/tLOSS: {:.3f}".format(
                ep, batch_idx, len(train_loader), batch_loss
            )
        )

    epoch_loss = total_loss / len(trainset)
    epoch_accuracy = num_correct / len(trainset)
    print(
        "EPOCH: {:5}/tLOSS: {:.3f}/tACCURACY: {:.3f}".format(
            ep, epoch_loss, epoch_accuracy
        )
    )


    
    # Validation loop (optional)
    #model.eval()
    #with torch.no_grad():
    #    for inputs, labels in val_loader:
    #        outputs = model(inputs)
            # Calculate validation loss and metrics

# Save the trained model
#torch.save(model.state_dict(), 'mobilenetv3_fine_tuned.pth')




EPOCH:     0/tBATCH:     0/   12/tLOSS: 3.745
EPOCH:     0/tBATCH:     1/   12/tLOSS: 7.821
EPOCH:     0/tBATCH:     2/   12/tLOSS: 4.039
EPOCH:     0/tBATCH:     3/   12/tLOSS: 4.437
EPOCH:     0/tBATCH:     4/   12/tLOSS: 3.551
EPOCH:     0/tBATCH:     5/   12/tLOSS: 3.106
EPOCH:     0/tBATCH:     6/   12/tLOSS: 3.540
EPOCH:     0/tBATCH:     7/   12/tLOSS: 2.220
EPOCH:     0/tBATCH:     8/   12/tLOSS: 3.269
EPOCH:     0/tBATCH:     9/   12/tLOSS: 3.076
EPOCH:     0/tBATCH:    10/   12/tLOSS: 2.372
EPOCH:     0/tBATCH:    11/   12/tLOSS: 2.769
EPOCH:     0/tLOSS: 0.118/tACCURACY: 0.194
EPOCH:     1/tBATCH:     0/   12/tLOSS: 1.772
EPOCH:     1/tBATCH:     1/   12/tLOSS: 1.713
EPOCH:     1/tBATCH:     2/   12/tLOSS: 1.189
EPOCH:     1/tBATCH:     3/   12/tLOSS: 1.006
EPOCH:     1/tBATCH:     4/   12/tLOSS: 0.682
EPOCH:     1/tBATCH:     5/   12/tLOSS: 1.203
EPOCH:     1/tBATCH:     6/   12/tLOSS: 0.670
EPOCH:     1/tBATCH:     7/   12/tLOSS: 0.644
EPOCH:     1/tBATCH:     8/   12/tLOS