In [None]:
import torch
from torch import nn
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
from tqdm import tqdm

class Caltech200Dataset(Dataset):
    """Image dataset driven by a text file listing relative image paths.

    Each non-blank line of ``txt_file`` is a path relative to ``root_dir``
    whose first component begins with a 1-based numeric class id, e.g.
    ``001.Class_Name/image.jpg``. Labels are stored 0-based.

    Every listed image is exposed ``num_augmentations`` consecutive times, so
    indices ``k * num_augmentations .. (k + 1) * num_augmentations - 1`` all
    map to image ``k``. The repeats only differ if ``transform`` is random.

    Args:
        root_dir: directory the listed paths are relative to.
        txt_file: path of the list file, one relative image path per line.
        transform: optional callable applied to the RGB PIL image.
        num_augmentations: how many times each image is repeated (>= 1).

    Raises:
        ValueError: if ``num_augmentations`` is smaller than 1, or a line's
            leading path component does not start with an integer class id.
    """

    def __init__(self, root_dir, txt_file, transform=None, num_augmentations=1):
        # Zero or negative repeats would surface later as a ZeroDivisionError
        # in __getitem__ or an invalid length, so reject them up front.
        if num_augmentations < 1:
            raise ValueError(
                f"num_augmentations must be >= 1, got {num_augmentations!r}"
            )

        self.root_dir = root_dir
        self.transform = transform
        self.num_augmentations = num_augmentations
        self.image_paths = []
        self.labels = []

        with open(txt_file, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                relative_path = line.strip()
                if not relative_path:
                    # Blank lines (e.g. stray trailing newlines) carry no image.
                    continue

                # Class id is the numeric prefix of the first path component
                # ("001.Class_Name/image.jpg" -> 1); shift to start from 0.
                class_prefix = relative_path.split('/')[0].split('.')[0]
                try:
                    label = int(class_prefix) - 1
                except ValueError as exc:
                    raise ValueError(
                        f"{txt_file}:{line_no}: cannot parse class id from "
                        f"{relative_path!r}"
                    ) from exc

                self.image_paths.append(os.path.join(root_dir, relative_path))
                self.labels.append(label)

    def __len__(self):
        """Number of listed images times ``num_augmentations``."""
        return len(self.image_paths) * self.num_augmentations

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the image that ``idx`` maps to."""
        # Consecutive indices share one underlying image.
        true_idx = idx // self.num_augmentations
        img_path = self.image_paths[true_idx]
        label = self.labels[true_idx]

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Set up environment
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNG so shuffling, random augmentations and weight
# initialisation start from the same state on every run.
SEED = 42
torch.manual_seed(SEED)

# Dataset location. Defaults to the original absolute path; set the
# CALTECH_BIRDS_DIR environment variable to run on another machine.
DATA_DIR = os.environ.get('CALTECH_BIRDS_DIR', '/home/feem/Workspace/caltech_birds')
root_dir = os.path.join(DATA_DIR, 'images')
TRAIN_LIST = os.path.join(DATA_DIR, 'lists', 'train.txt')
TEST_LIST = os.path.join(DATA_DIR, 'lists', 'test.txt')

# Loader settings. The worker count is capped at the number of CPUs reported
# by the OS so the notebook also runs on machines with fewer than 12 cores.
BATCH_SIZE = 64
NUM_WORKERS = min(12, os.cpu_count() or 1)
NUM_TRAIN_AUGMENTATIONS = 20  # each training image is repeated this many times per epoch

# Channel statistics used for input normalisation (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _preprocessing_steps():
    """Return the deterministic resize/crop/tensor/normalise steps shared by both splits."""
    return [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Prepare the dataset
# Training: random augmentations first, then the shared preprocessing.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(10),
    *_preprocessing_steps(),
])

# Evaluation: shared preprocessing only, no randomness.
test_transform = transforms.Compose(_preprocessing_steps())

train_dataset = Caltech200Dataset(root_dir=root_dir,
                                  txt_file=TRAIN_LIST,
                                  transform=train_transform,
                                  num_augmentations=NUM_TRAIN_AUGMENTATIONS)
test_dataset = Caltech200Dataset(root_dir=root_dir,
                                 txt_file=TEST_LIST,
                                 transform=test_transform,
                                 num_augmentations=1)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Define the model: a pretrained ResNet-18 whose final layer is swapped for a
# deeper classification head.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version before changing.
model = torchvision.models.resnet18(pretrained=True)

# No backbone parameters are frozen, so the whole network is fine-tuned.
n_inputs = model.fc.in_features

# Head layout: three (Linear -> SiLU -> Dropout(0.5)) stages with widths
# 1024, 2048 and 2048, followed by a Linear layer producing 200 class logits.
head_layers = []
in_width = n_inputs
for out_width in (1024, 2048, 2048):
    head_layers += [nn.Linear(in_width, out_width), nn.SiLU(), nn.Dropout(0.5)]
    in_width = out_width
head_layers.append(nn.Linear(in_width, 200))

model.fc = nn.Sequential(*head_layers)
model = model.to(device)

# Set up training loop: cross-entropy loss, Adam over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Imports used by the training loop below (epoch timing and per-epoch metric history)
from time import time
from collections import defaultdict

# Training and evaluation functions
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Args:
        model: network to train; it is switched to training mode here.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor (assumed to be a per-batch mean).
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-sample loss and accuracy in
        percent, both computed over the samples actually iterated.
        ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(loader, desc="Training", leave=False)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so a smaller final batch does
        # not skew the epoch average.
        batch_loss = loss.item()
        running_loss += batch_loss * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({'loss': batch_loss, 'accuracy': 100 * correct / max(total, 1)})

    if total == 0:
        # Nothing was iterated; avoid dividing by zero.
        return 0.0, 0.0

    # Divide by the samples seen, not len(loader.dataset): the two differ when
    # the loader drops the last batch or draws from a sampler.
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: network to evaluate; it is switched to eval mode here and left
            in that mode on return.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor (assumed to be a per-batch mean).
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-sample loss and accuracy in
        percent, both computed over the samples actually iterated.
        ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # Gradients are not needed for scoring.
    with torch.no_grad():
        progress_bar = tqdm(loader, desc="Evaluating", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(images)
            batch_loss = criterion(outputs, labels).item()

            # Weight by batch size so a smaller final batch does not skew
            # the average.
            running_loss += batch_loss * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            progress_bar.set_postfix({'loss': batch_loss, 'accuracy': 100 * correct / max(total, 1)})

    if total == 0:
        # Nothing was iterated; avoid dividing by zero.
        return 0.0, 0.0

    # Divide by the samples seen, not len(loader.dataset): the two differ when
    # the loader drops the last batch or draws from a sampler.
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

# Training loop: train, score on test_loader, log, and checkpoint each epoch.
# NOTE(review): the "val" metrics come from test_loader, so the best-model
# checkpoint is selected on the test split.
num_epochs = 10
best_acc = 0.0
history = defaultdict(list)

for epoch in range(num_epochs):
    start_time = time()

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    # Append this epoch's four metrics to their per-metric series.
    for metric_name, metric_value in (
        ('train_loss', train_loss),
        ('train_acc', train_acc),
        ('val_loss', val_loss),
        ('val_acc', val_acc),
    ):
        history[metric_name].append(metric_value)

    epoch_time = time() - start_time

    report = (
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% | "
        f"Time: {epoch_time:.2f}s"
    )
    print(report)

    # Checkpoint only when this epoch strictly beats the best accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_model_0.pth')
        print(f"Best model saved with accuracy: {best_acc:.2f}%")

# Always keep the weights from the last epoch as well.
torch.save(model.state_dict(), 'final_model_0.pth')

print(f"Training completed. Best accuracy: {best_acc:.2f}%")

In [None]:
# Hard-coded reference value (presumably the best accuracy from an earlier
# run — confirm); nothing here is computed from the current run.
print('previous best 42.27')