In [8]:
import torch
import torch.nn as nn
import torchvision

from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18

In [2]:
DATASET_NAME = 'cats_vs_dogs'
datasets = load_dataset ( DATASET_NAME )
datasets

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 23410
    })
})

In [3]:
TEST_SIZE = 0.2
datasets = datasets["train"].train_test_split(test_size=TEST_SIZE)

In [5]:
IMG_SIZE = 64
img_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [13]:
class CatDogDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        images = self.data[idx]["image"]
        labels = self.data[idx]["labels"]
        
        if self.transform:
            images = self.transform(images)
            
        labels = torch.tensor(labels, dtype=torch.long)
        
        return images, labels

In [14]:
TRAIN_BATCH_SIZE = 512
VAL_BATCH_SIZE = 256

train_dataset = CatDogDataset(datasets["train"], img_transform)
test_dataset = CatDogDataset(datasets["test"], img_transform)

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

In [15]:
class CatDogModel(nn.Module):
    def __init__(self, n_classes):
        super(CatDogModel, self).__init__()
        
        resnet_model = torchvision.models.resnet18(pretrained=True)
        self.backbone = nn.Sequential(*list(resnet_model.children())[:-1])
        for param in self.backbone.parameters():
            param.requires_grad = False
        
        in_features = resnet_model.fc.in_features
        self.fc = nn.Linear(in_features, n_classes)
        
    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
N_Classes = 2
model = CatDogModel(N_Classes).to(device)
test_input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    test_output = model(test_input)
    print(test_output, test_output.shape)

tensor([[1.0083, 0.5857]], device='cuda:0') torch.Size([1, 2])




In [17]:
EPOCHS = 100
LR = 1e-3
WEIGHT_DECAY = 1e-4

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    train_losses = []
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_losses.append(loss.item())
        
    train_loss = sum(train_losses) / len(train_losses)
    
    val_losses = []
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            val_losses.append(loss.item())
            
    val_loss = sum(val_losses) / len(val_losses)
    
    print(f'Epoch: {epoch + 1}, Train Loss: {train_loss}, Val Loss: {val_loss}')

Epoch: 1, Train Loss: 0.6511853095647451, Val Loss: 0.6124513525711862
Epoch: 2, Train Loss: 0.5461576226595286, Val Loss: 0.5547299447812533
Epoch: 3, Train Loss: 0.533674554244892, Val Loss: 0.5335484978399778
Epoch: 4, Train Loss: 0.515381004359271, Val Loss: 0.5251000664736095
Epoch: 5, Train Loss: 0.509812476667198, Val Loss: 0.5254620846949125
Epoch: 6, Train Loss: 0.5083155068191322, Val Loss: 0.5203162619942113
Epoch: 7, Train Loss: 0.50482288808436, Val Loss: 0.5225035783491636
Epoch: 8, Train Loss: 0.5046641456114279, Val Loss: 0.5164312783040499
Epoch: 9, Train Loss: 0.5092968723258456, Val Loss: 0.519756942987442
Epoch: 10, Train Loss: 0.5032809843888154, Val Loss: 0.5167634095016279
Epoch: 11, Train Loss: 0.4982039203514924, Val Loss: 0.5156522769677011
Epoch: 12, Train Loss: 0.4998998827225453, Val Loss: 0.5166904236141004
Epoch: 13, Train Loss: 0.49590026285197286, Val Loss: 0.5155981779098511
Epoch: 14, Train Loss: 0.49597351776586995, Val Loss: 0.512920302780051
Epoch:

In [None]:
SAVE_PATH = 'catdog_model.pt'
torch.save(model.state_dict(), SAVE_PATH)