In [2]:
from src.models import ResNet50
from src.dataset import Animals
from src.train import Train
from src.helper import accuracy_fn

In [3]:
import torch 
from torch import nn 
from torch.utils.data import DataLoader

In [4]:
torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Parameters
BATCH_SIZE = 128
EPOCHS = 50

In [6]:
train_data,val_data,test_data = Animals.train_test_val_data()

In [7]:
animals = Animals(root='src/data/animal_data')

In [8]:
class_names = animals.get_class_to_idx()

In [9]:
class_names

{'Bear': 0,
 'Bird': 1,
 'Cat': 2,
 'Cow': 3,
 'Deer': 4,
 'Dog': 5,
 'Dolphin': 6,
 'Elephant': 7,
 'Giraffe': 8,
 'Horse': 9,
 'Kangaroo': 10,
 'Lion': 11,
 'Panda': 12,
 'Tiger': 13,
 'Zebra': 14}

In [10]:
train_loader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)
val_loader = DataLoader(val_data,batch_size=BATCH_SIZE,shuffle=False)
test_loader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=False)

In [11]:
model = ResNet50(num_classes=15).to(device)

In [12]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(),lr=0.00001)

In [9]:
trainer = Train(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    accuracy_fn=accuracy_fn,
    device=device,
    l2_lambda=0.001
)

In [10]:
trainer.train(num_epochs=EPOCHS)

Train Epoch 1: Loss: 3.65750 | Accuracy: 8.78
Test Epoch 1: Loss: 2.73726 | Accuracy: 5.97
Train Epoch 2: Loss: 3.42089 | Accuracy: 12.37
Test Epoch 2: Loss: 2.73966 | Accuracy: 5.97
Train Epoch 3: Loss: 3.25339 | Accuracy: 24.09
Test Epoch 3: Loss: 2.66791 | Accuracy: 10.97
Train Epoch 4: Loss: 3.12065 | Accuracy: 32.24
Test Epoch 4: Loss: 2.40516 | Accuracy: 23.45
Train Epoch 5: Loss: 2.97101 | Accuracy: 39.32
Test Epoch 5: Loss: 2.11773 | Accuracy: 41.52
Train Epoch 6: Loss: 2.84287 | Accuracy: 43.11
Test Epoch 6: Loss: 1.93983 | Accuracy: 50.11
Train Epoch 7: Loss: 2.67944 | Accuracy: 51.61
Test Epoch 7: Loss: 1.73022 | Accuracy: 58.57
Train Epoch 8: Loss: 2.53190 | Accuracy: 57.97
Test Epoch 8: Loss: 1.50541 | Accuracy: 66.11
Train Epoch 9: Loss: 2.36614 | Accuracy: 62.60
Test Epoch 9: Loss: 1.35194 | Accuracy: 67.66
Train Epoch 10: Loss: 2.19831 | Accuracy: 68.79
Test Epoch 10: Loss: 1.26204 | Accuracy: 70.36
Train Epoch 11: Loss: 1.96504 | Accuracy: 75.91
Test Epoch 11: Loss: 1.

In [11]:
# Save the model
torch.save(model.state_dict(), 'model.pth')

In [13]:
# Load the model parameters
model.load_state_dict(torch.load('model.pth'))

<All keys matched successfully>