In [1]:
import pandas as pd
import numpy as np
import torch
import os
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

In [2]:
class Lego_Dataset(torch.utils.data.Dataset):
    def __init__(self, file_paths, path, labels, transform=None):
        """
        Args:
            file_paths (list): List of file paths for the images.
            labels (list): List of corresponding labels.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.path = path

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = Image.open(os.path.join(self.path,img_path)).convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]

        return image, label

In [3]:
#list of filenames
path = "/Users/LeMarx/Documents/01_Projects/mlops_project/data/external/lego_dataset"
index = pd.read_csv("/Users/LeMarx/Documents/01_Projects/mlops_project/data/external/lego_dataset/index.csv")
labels = index["class_id"]-1
files = index["path"]

In [4]:
labels

0       0
1       0
2       0
3       0
4       0
       ..
366    37
367    37
368    37
369    37
370    37
Name: class_id, Length: 371, dtype: int64

In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

In [6]:
trainset = Lego_Dataset(file_paths=files, path = path, labels=labels,transform=transform)

In [7]:
trainset[0][0].shape

torch.Size([3, 512, 512])

In [8]:
train_loader = DataLoader(trainset, batch_size=32, shuffle=True)

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

num_epochs = 5
lr = 0.003

num_classes = 38 


#model = timm.create_model('mobilenetv3_small_100', pretrained=True, num_classes=num_classes)

model = nn.Sequential(
    nn.Conv2d(in_channels= 3, out_channels=32,kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=32,out_channels=16,kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(in_features=254016,out_features=num_classes)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


# Adjust this according to your needs


for ep in range(num_epochs):
    

    total_loss = 0
    num_correct = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        y_hat = model(inputs)
        batch_loss = criterion(y_hat, labels)
        batch_loss.backward()
        optimizer.step()

        total_loss += float(batch_loss)
        num_correct += int(torch.sum(torch.argmax(y_hat, dim=1) == labels))

        
        print(
            "EPOCH: {:5}\tBATCH: {:5}/{:5}\tLOSS: {:.3f}".format(
                ep, batch_idx, len(train_loader), batch_loss
            )
        )

    epoch_loss = total_loss / len(trainset)
    epoch_accuracy = num_correct / len(trainset)
    print(
        "EPOCH: {:5}\tLOSS: {:.3f}\tACCURACY: {:.3f}".format(
            ep, epoch_loss, epoch_accuracy
        )
    )


    
    # Validation loop (optional)
    #model.eval()
    #with torch.no_grad():
    #    for inputs, labels in val_loader:
    #        outputs = model(inputs)
            # Calculate validation loss and metrics

# Save the trained model
#torch.save(model.state_dict(), 'mobilenetv3_fine_tuned.pth')




EPOCH:     0	BATCH:     0/   12	LOSS: 3.649
EPOCH:     0	BATCH:     1/   12	LOSS: 74.196
EPOCH:     0	BATCH:     2/   12	LOSS: 41.695
EPOCH:     0	BATCH:     3/   12	LOSS: 11.657
EPOCH:     0	BATCH:     4/   12	LOSS: 3.670
EPOCH:     0	BATCH:     5/   12	LOSS: 3.639
EPOCH:     0	BATCH:     6/   12	LOSS: 3.637
EPOCH:     0	BATCH:     7/   12	LOSS: 3.639
EPOCH:     0	BATCH:     8/   12	LOSS: 3.639
EPOCH:     0	BATCH:     9/   12	LOSS: 3.637
EPOCH:     0	BATCH:    10/   12	LOSS: 3.636
EPOCH:     0	BATCH:    11/   12	LOSS: 3.637
EPOCH:     0	LOSS: 0.432	ACCURACY: 0.022
EPOCH:     1	BATCH:     0/   12	LOSS: 3.637
EPOCH:     1	BATCH:     1/   12	LOSS: 3.637
EPOCH:     1	BATCH:     2/   12	LOSS: 3.638
EPOCH:     1	BATCH:     3/   12	LOSS: 3.636
EPOCH:     1	BATCH:     4/   12	LOSS: 3.640
EPOCH:     1	BATCH:     5/   12	LOSS: 3.635
EPOCH:     1	BATCH:     6/   12	LOSS: 3.637


KeyboardInterrupt: 