# Set up

In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchtune.modules import get_cosine_schedule_with_warmup
import pandas as pd
from tqdm import tqdm
from PIL import Image
from models.resnet import ResNet101
from models.vision_transformer import VisionTransformer

In [None]:
def load_dataset(
    root='./data',
    batch_size=64,
    num_workers=0,
    seed=42,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train / validation / test DataLoaders from ImageFolder splits.

    Expects ``root/train``, ``root/val`` and ``root/test`` to each contain one
    sub-directory per class (the layout ``ImageFolder`` reads).

    Args:
        root: Directory holding the three split folders.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per DataLoader. ``0`` (the default) loads
            images in the main process.
        seed: Value passed to ``torch.manual_seed`` before the loaders are built.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. Only the training loader
        shuffles and applies random augmentation.
    """
    # NOTE: this seeds torch's *global* RNG, i.e. it affects every later random
    # operation in the process, not only these loaders.
    torch.manual_seed(seed)

    # Random augmentation, applied to the training split only.
    data_augmentation = transforms.Compose([
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    ])
    # Resize -> tensor -> per-channel standardisation, shared by all splits.
    # The mean/std are not the ImageNet values; presumably they were computed on
    # this dataset (TODO confirm). itest() repeats them and must stay in sync.
    preprocess = transforms.Compose([
        transforms.Resize(size=(224, 224), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize([0.7037, 0.6818, 0.6685], [0.2739, 0.2798, 0.2861]),
    ])

    def make_loader(split: str, transform, shuffle: bool) -> DataLoader:
        # One ImageFolder per split directory, wrapped with the shared loader settings.
        dataset = ImageFolder(root=os.path.join(root, split), transform=transform)
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_loader = make_loader(
        'train', transforms.Compose([data_augmentation, preprocess]), shuffle=True
    )
    valid_loader = make_loader('val', preprocess, shuffle=False)
    test_loader = make_loader('test', preprocess, shuffle=False)

    return train_loader, valid_loader, test_loader

In [5]:
# Build the three loaders up front with the default data root and batch size
# (load_dataset also reseeds torch's global RNG).
(train_loader, valid_loader, test_loader) = load_dataset(root='./data', batch_size=64)

In [None]:
def _run_epoch(model, loader, loss_fn, device, desc, optimizer=None):
    """Make one pass over ``loader``; return ``(mean batch loss, accuracy in %)``.

    With an ``optimizer`` the model is put in train mode and updated on every
    batch. Without one it is put in eval mode and run with gradients disabled.
    Raises ValueError if ``loader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0
    with tqdm(total=len(loader), desc=desc, unit='batch') as pbar, torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            num_batches += 1
            running_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
            pbar.set_postfix({'loss': f'{running_loss / num_batches:.3f}'})
            pbar.update()
        if num_batches == 0:
            raise ValueError(f'{desc}: data loader yielded no batches')
        mean_loss = running_loss / num_batches
        accuracy = correct / seen * 100
        pbar.set_postfix({'loss': f'{mean_loss:.3f}', 'acc': f'{accuracy:.2f}'})
    return mean_loss, accuracy


def train(
    model: nn.Module,
    save_path: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_warmup_steps=5,
    num_epochs=100,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
    device='cpu',
) -> pd.DataFrame:
    """Train ``model`` with SGD and a warmup + cosine LR schedule, validating every epoch.

    After each epoch the weights are written to ``<save_path>/<epoch>_weights.pth``
    (epochs numbered from 1). All metrics so far are also written to
    ``<save_path>/train_result.csv``. ``save_path`` is created if missing.

    Args:
        model: Network to train. It is moved to ``device``.
        save_path: Output directory for checkpoints and the metrics CSV.
        train_loader: Batches used for the SGD updates.
        val_loader: Batches used for evaluation after each epoch.
        num_warmup_steps: Warmup length of the LR schedule, in epochs (the
            scheduler is stepped once per epoch).
        num_epochs: Number of passes over ``train_loader``. This is also the
            schedule's total length.
        lr, momentum, weight_decay: SGD hyper-parameters.
        device: Device the model and batches are moved to.

    Returns:
        DataFrame with columns ``epoch, train_loss, train_acc, val_loss, val_acc``
        (accuracies in percent).

    Raises:
        ValueError: if either loader yields no batches.
    """
    model = model.to(device)
    # torch.save does not create directories, so make sure the output folder exists.
    os.makedirs(save_path, exist_ok=True)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_epochs,
    )
    results_csv = os.path.join(save_path, 'train_result.csv')

    print(f'Training with {device}')

    result = []
    for epoch in range(num_epochs):
        epoch_label = f'Epoch {epoch+1}/{num_epochs}'
        train_loss, train_acc = _run_epoch(
            model, train_loader, loss_fn, device, f'Train {epoch_label}', optimizer=optimizer,
        )
        val_loss, val_acc = _run_epoch(model, val_loader, loss_fn, device, f'Val {epoch_label}')

        # Epoch-level schedule: one step per completed epoch.
        scheduler.step()

        torch.save(model.state_dict(), os.path.join(save_path, f'{epoch+1}_weights.pth'))

        result.append({'epoch': epoch+1, 'train_loss': train_loss, 'train_acc': train_acc, 'val_loss': val_loss, 'val_acc': val_acc})
        # Rewritten every epoch so the metrics survive an interrupted run.
        pd.DataFrame(result).to_csv(results_csv, index=False)

    return pd.DataFrame(result)

In [None]:
def test(
    model: nn.Module,
    test_loader: DataLoader,
    weights_path: str,
    device='cpu',
) -> float:
    """Load ``weights_path`` into ``model`` and report its accuracy on ``test_loader``.

    Args:
        model: Network whose architecture matches the checkpoint.
        test_loader: Labelled batches to evaluate on.
        weights_path: Path to a state-dict file, e.g. one written by ``train``
            (``<save_path>/<epoch>_weights.pth``).
        device: Device to evaluate on. The checkpoint tensors are mapped onto it.

    Returns:
        Top-1 accuracy in percent (also printed).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model = model.to(device).eval()
    # map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    correct = 0
    seen = 0
    with tqdm(total=len(test_loader), desc='Test', unit='batch') as pbar, torch.inference_mode():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
            pbar.update()
        if seen == 0:
            raise ValueError('test_loader yielded no samples')
        accuracy = correct / seen * 100
        pbar.set_postfix({'acc': accuracy})

    print(f'\nAccuracy: {accuracy}%')
    return accuracy

In [None]:
def itest(
    model: nn.Module,
    image_path: str,
    weights_path: str,
    device='cpu',
    classes=None,
) -> str:
    """Classify a single image file, print every class probability, and return the top class.

    Args:
        model: Network whose architecture matches the checkpoint.
        image_path: Image file to classify. It is converted to RGB first, as
            torchvision's default ImageFolder loader does for the training data.
        weights_path: State-dict file, e.g. ``<save_path>/<epoch>_weights.pth`` from ``train``.
        device: Device to run inference on. The checkpoint is mapped onto it.
        classes: Class names indexed by model output position. Defaults to the
            ten categories below. ImageFolder assigns label indices in sorted
            folder-name order, so this list must match the training sub-folders in that order.

    Returns:
        The predicted class name (also printed).

    Raises:
        ValueError: if the model's output size differs from ``len(classes)``.
    """
    if classes is None:
        classes = [
            'beauty_products',
            'electronics',
            'fashion',
            'fitness_equipments',
            'furniture',
            'home_appliances',
            'kitchenware',
            'musical_instruments',
            'study_things',
            'toys',
        ]
    model = model.to(device).eval()
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))

    # Same deterministic preprocessing load_dataset() uses for the val/test splits.
    preprocess = transforms.Compose([
        transforms.Resize(size=(224, 224), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize([0.7037, 0.6818, 0.6685], [0.2739, 0.2798, 0.2861]),
    ])
    # Force 3 channels: grayscale, palette or RGBA files would otherwise not match
    # the 3-channel Normalize. The context manager also closes the file handle.
    with Image.open(image_path) as img:
        input_tensor = preprocess(img.convert('RGB'))
    input_batch = input_tensor.unsqueeze(0).to(device)

    with torch.inference_mode():
        output = model(input_batch)
    probs = torch.nn.functional.softmax(output[0], dim=0)
    if probs.numel() != len(classes):
        raise ValueError(
            f'model produced {probs.numel()} outputs but {len(classes)} class names were given'
        )

    for name, prob in zip(classes, probs.tolist()):
        print(f'{name}: {prob*100:.2f}%')

    prediction = classes[int(probs.argmax())]
    print(f'\nPrediction: {prediction}')
    return prediction

# Train and Test ResNet-101

In [None]:
# Fresh ResNet-101 sized for the 10 classes of this dataset.
model = ResNet101(
    num_classes=10,
)

In [None]:
# Rebuild the loaders (this reseeds torch's global RNG) before training ResNet-101.
(train_loader, val_loader, test_loader) = load_dataset(root='./data', batch_size=64)

In [None]:
# train() writes per-epoch checkpoints and train_result.csv into this folder, and
# torch.save does not create missing directories.
os.makedirs('resnet101', exist_ok=True)
train(model, 'resnet101', train_loader, val_loader, device='cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# test(model, test_loader, weights_path, ...): the loader is the second argument and
# the third must be a checkpoint *file*, not the run directory.
# Evaluate the epoch with the best validation accuracy recorded by train().
resnet_history = pd.read_csv(os.path.join('resnet101', 'train_result.csv'))
resnet_best_epoch = int(resnet_history.loc[resnet_history['val_acc'].idxmax(), 'epoch'])
test(
    model,
    test_loader,
    os.path.join('resnet101', f'{resnet_best_epoch}_weights.pth'),
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# Train and Test ViT

In [None]:
# Fresh Vision Transformer sized for the 10 classes of this dataset.
model = VisionTransformer(
    num_classes=10,
)

In [None]:
# Rebuild the loaders (this reseeds torch's global RNG) before training the ViT.
(train_loader, val_loader, test_loader) = load_dataset(root='./data', batch_size=64)

In [None]:
# 'Vit-B/16' is a nested path (Vit-B/16/...). torch.save inside train() needs it to
# exist already.
os.makedirs('Vit-B/16', exist_ok=True)
train(model, 'Vit-B/16', train_loader, val_loader, device='cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# test(model, test_loader, weights_path, ...): the loader is the second argument and
# the third must be a checkpoint *file*, not the run directory.
# Evaluate the epoch with the best validation accuracy recorded by train().
vit_history = pd.read_csv(os.path.join('Vit-B/16', 'train_result.csv'))
vit_best_epoch = int(vit_history.loc[vit_history['val_acc'].idxmax(), 'epoch'])
test(
    model,
    test_loader,
    os.path.join('Vit-B/16', f'{vit_best_epoch}_weights.pth'),
    device='cuda' if torch.cuda.is_available() else 'cpu',
)