In [2]:
# 0. Setup

import torch
import torchvision
from torchvision import models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from nested_image_folder import NestedImageFolder
import matplotlib.pyplot as plt
import os

In [3]:
#1. Load Data

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

data_dir = "/Users/luizacomanescu/git/bali-style-net/dataset/"  # path to your dataset
dataset = NestedImageFolder(root=data_dir, transform=transform)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
class_names = dataset.classes

In [4]:
#2. Load Pretrained ResNet50

model = models.resnet50(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # freeze all layers

# Replace final layer
num_classes = len(class_names)
model.fc = nn.Linear(model.fc.in_features, num_classes)



In [6]:
#3. Training Setup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

AttributeError: 'SwinTransformer' object has no attribute 'fc'

In [7]:
#4. Training Loop

num_epochs = 50
for epoch in range(num_epochs):
    running_loss = 0.0
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f}")

Epoch 1/50 - Loss: 0.7128
Epoch 2/50 - Loss: 0.6433
Epoch 3/50 - Loss: 0.6145
Epoch 4/50 - Loss: 0.5811
Epoch 5/50 - Loss: 0.5712
Epoch 6/50 - Loss: 0.5507
Epoch 7/50 - Loss: 0.5427
Epoch 8/50 - Loss: 0.5358
Epoch 9/50 - Loss: 0.5004
Epoch 10/50 - Loss: 0.4912
Epoch 11/50 - Loss: 0.4475
Epoch 12/50 - Loss: 0.4008
Epoch 13/50 - Loss: 0.3970
Epoch 14/50 - Loss: 0.4005
Epoch 15/50 - Loss: 0.3649
Epoch 16/50 - Loss: 0.3754
Epoch 17/50 - Loss: 0.3891
Epoch 18/50 - Loss: 0.3399
Epoch 19/50 - Loss: 0.3964
Epoch 20/50 - Loss: 0.3560
Epoch 21/50 - Loss: 0.3589
Epoch 22/50 - Loss: 0.3251
Epoch 23/50 - Loss: 0.3018
Epoch 24/50 - Loss: 0.2821
Epoch 25/50 - Loss: 0.2748
Epoch 26/50 - Loss: 0.2658
Epoch 27/50 - Loss: 0.2776
Epoch 28/50 - Loss: 0.2652
Epoch 29/50 - Loss: 0.2573
Epoch 30/50 - Loss: 0.2592
Epoch 31/50 - Loss: 0.2578
Epoch 32/50 - Loss: 0.2303
Epoch 33/50 - Loss: 0.2486
Epoch 34/50 - Loss: 0.2423
Epoch 35/50 - Loss: 0.2579
Epoch 36/50 - Loss: 0.2234
Epoch 37/50 - Loss: 0.2189
Epoch 38/5

In [13]:
import timm
import torch.nn as nn

# Create Swin model with number of classes equal to your dataset's classes
model = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=True,
    num_classes=num_classes  # this is key!
)

In [16]:
import torch
import torch.optim as optim
import torch.nn as nn

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Training loop (simplified)
epochs = 50
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader):.4f}")

Epoch 1/50 - Loss: 1.1322
Epoch 2/50 - Loss: 0.3245
Epoch 3/50 - Loss: 0.1171
Epoch 4/50 - Loss: 0.0684


KeyboardInterrupt: 

In [15]:
x = torch.randn(4, 3, 224, 224)  # sample input
y = model(x)
print(y.shape)  # should be [4, num_classes]

torch.Size([4, 6])


In [14]:
print(num_classes)

6


In [19]:
# # Save ResNet model
# torch.save(resnet_model.state_dict(), 'resnet_model.pth')

# Save Swin Transformer model
torch.save(model.state_dict(), 'swin_model.pth')