In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image, UnidentifiedImageError
from torchvision import datasets, models, transforms

In [2]:
root_dir = "../input/microsoft-catsvsdogs-dataset/PetImages"

In [3]:
class CatsAndDogsDataset(Dataset):
    def __init__(self, root_dir, transform=None, start=0, finish=1000):
        self.root_dir = root_dir
        self.transform = transform
        
        self.dog_files = os.listdir(os.path.join(root_dir, "Dog"))[start:finish]
        self.cat_files = os.listdir(os.path.join(root_dir, "Cat"))[start:finish]
        
        self.length = min(len(self.dog_files), len(self.cat_files))
    
    def __len__(self):
        return self.length * 2
    

    def __getitem__(self, idx):
        try:
            if idx % 2 == 0:
                folder = "Dog"
                image_files = self.dog_files
                label = 1
            else:
                folder = "Cat"
                image_files = self.cat_files
                label = 0  # Cat label
            
            adjusted_idx = idx // 2
            img_path = os.path.join(self.root_dir, folder, image_files[adjusted_idx])
            
            image = Image.open(img_path).convert("RGB")
            
            if self.transform:
                image = self.transform(image)
            
            return image, label
            
        except (UnidentifiedImageError, OSError) as e:
            print(f"Skipping corrupted image: {img_path}")
            return self.__getitem__((idx + 2) % len(self))


In [4]:
transform = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

In [5]:
bs = 64

train_dataset = CatsAndDogsDataset(root_dir=root_dir, transform=transform['train'],finish=8000)
val_dataset = CatsAndDogsDataset(root_dir=root_dir, transform=transform['val'], start=8000, finish=9000)

train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False)

In [6]:
model = models.resnet18(pretrained=True)



In [7]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [8]:
num_features = model.fc.in_features
num_features

512

In [9]:
model.fc = nn.Linear(num_features, 2)  # 2 classes: dog, cat

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
# # Freeze all layers except the final fully connected layer
# for param in model.parameters():
#     param.requires_grad = False  # Freeze all parameters

# # Unfreeze the final layer
# model.fc = nn.Linear(model.fc.in_features, 2)
# model.fc.requires_grad = True
# model = model.to(device)

In [12]:
# # Freeze all layers initially
# for param in model.parameters():
#     param.requires_grad = False

# # Unfreeze the last few layers
# for name, param in model.named_parameters():
#     if "layer4" in name or "fc" in name:  # Unfreeze layer4 (last conv block) and fc
#         param.requires_grad = True

# # Update model to the device
# model = model.to(device)


In [13]:
%%time

epochs = 1
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}, Accuracy: {correct/total:.2f}")



Skipping corrupted image: ../input/microsoft-catsvsdogs-dataset/PetImages/Dog/11702.jpg
Epoch 1/1, Loss: 0.14303743577003478, Accuracy: 0.94
CPU times: user 2min 53s, sys: 15.2 s, total: 3min 8s
Wall time: 1min 41s


In [14]:
model.eval()
val_correct = 0
val_total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        val_correct += (preds == labels).sum().item()
        val_total += labels.size(0)

print(f"Validation Accuracy: {val_correct/val_total:.2f}")


Skipping corrupted image: ../input/microsoft-catsvsdogs-dataset/PetImages/Dog/Thumbs.db
Skipping corrupted image: ../input/microsoft-catsvsdogs-dataset/PetImages/Cat/Thumbs.db
Validation Accuracy: 0.93


In [None]:
dog = f'{root_dir}/Dog/3863.jpg'
cat = f'{root_dir}/Cat/5307.jpg'

In [None]:
def return_pred(img_path: str):
    preprocessed_image = transform['val'](img).unsqueeze(0)
    return model(preprocessed_image.to(device)).argmax(1)    

In [None]:
Image.open(dog)

In [None]:
return_pred(dog)

In [None]:
Image.open(cat)

In [None]:
return_pred(cat)

When should I only train the last linear layer?
1. Small dataset - to avoid overfitting
2. Similar task - it might not make sense
3. When I don't want to spend much computations