In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms, Lambda
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights

# Seed torch's global RNG; it drives the random init of any newly created layers and the
# DataLoader's shuffling order, so fixing it makes reruns comparable.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained MobileNetV3-Small. `weights` is keyword-only since torchvision 0.13; passing it
# positionally only works through a deprecation shim that emits a warning.
model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
model.to(device)

# Location of the MNIST download and the batch size shared by both splits.
DATA_ROOT = "../data"
BATCH_SIZE = 32

# MNIST yields single-channel PIL images, while the pretrained network takes 3-channel input,
# so the grey channel is replicated before the weights' preset preprocessing is applied.
# Grayscale(num_output_channels=3) does the same as `img.convert("RGB")` for greyscale images,
# but unlike a lambda it can be pickled if DataLoader workers (num_workers > 0) are enabled.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    MobileNet_V3_Small_Weights.DEFAULT.transforms(),
])

train_dataset = MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Shuffle only the training split; evaluation order does not matter.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
from torch.optim import Adam
import torch.nn as nn

# MNIST has ten digit classes.
num_classes = 10

# Swap the pretrained output layer (classifier[3]) for a fresh 10-way Linear layer.
# A newly constructed layer lives on the CPU; `model` was already moved to `device`, so the new
# layer must be moved too or the forward pass fails with a device mismatch on CUDA.
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes).to(device)

loss_fn = nn.CrossEntropyLoss()
# Optimizes every parameter of the model; later cells rebuild the optimizer after freezing layers.
optimizer = Adam(model.parameters(), lr=0.001)



In [None]:
# Experiment: freeze only the first few blocks of the feature extractor and fine-tune the rest.
# NOTE(review): the next cell freezes *all* of `model.features` and rebuilds the optimizer, so
# this partial freeze is superseded when the notebook is run top to bottom.
N_FROZEN_FEATURE_BLOCKS = 5

# Slicing an nn.Sequential returns a Sequential over the same submodules, so this turns off
# gradients for the parameters of the first N_FROZEN_FEATURE_BLOCKS blocks of the model itself.
model.features[:N_FROZEN_FEATURE_BLOCKS].requires_grad_(False)

# List every parameter with its trainable flag to confirm which ones are frozen.
for name, param in model.named_parameters():
    print(name, param.requires_grad)

# Hand only the still-trainable parameters to the optimizer.
optimizer = Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)


In [None]:
# Freeze the whole feature extractor so only the remaining (classifier) parameters are trained.
model.features.requires_grad_(False)

# Replace the output layer with a freshly initialised 10-way head. It is created on the CPU,
# so move it to `device` to match the rest of the model; otherwise the forward pass fails with
# a device mismatch on CUDA.
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes).to(device)

# Only parameters that still require gradients are optimized.
optimizer = Adam(params=(p for p in model.parameters() if p.requires_grad), lr=0.001)



In [None]:
def train_one_epoch(model, dataloader, loss_fn, optimizer, device, log_every=100):
    """Train ``model`` for one full pass over ``dataloader``.

    Prints the running mean loss and accuracy on batch 0 and every ``log_every`` batches.

    Args:
        model: module to train; it is switched to training mode here.
        dataloader: iterable of ``(inputs, targets)`` batches.
        loss_fn: criterion called as ``loss_fn(logits, targets)``.
        optimizer: optimizer over the trainable parameters of ``model``.
        device: device each batch is moved to (should match the model's device).
        log_every: progress-print interval, in batches.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample seen in the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    running_loss, num_correct, samples = 0.0, 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(X)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # Assumes loss_fn returns the per-batch mean (the CrossEntropyLoss default); weighting by
        # batch size keeps a short final batch from skewing the epoch average.
        running_loss += loss.item() * batch_size
        num_correct += (logits.argmax(dim=1) == y).sum().item()
        samples += batch_size

        if batch % log_every == 0:
            print(f"Batch: {batch} ----------------")
            print(f"Loss = {running_loss / samples}")
            print(f"Accuracy = {num_correct / samples}")

    if samples == 0:
        raise ValueError("dataloader yielded no samples")
    return running_loss / samples, num_correct / samples


# NOTE(review): runs a single epoch, and `test_dataloader` is not evaluated here.
# NOTE(review): model.train() still updates BatchNorm running statistics inside the frozen
# feature extractor (torchvision's MobileNetV3 uses BatchNorm) — confirm this is intended.
train_loss, train_acc = train_one_epoch(model, train_dataloader, loss_fn, optimizer, device)
print(f"total loss: {train_loss}")
print(f"total accuracy: {train_acc}")