In [2]:
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import models
import multiprocessing

In [15]:
img_dir = "../data/train_images"
df = pd.read_csv("../data/train.csv")
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['label'],
    random_state=42
)
train_df.shape, val_df.shape

((8325, 4), (2082, 4))

In [4]:
age_mean = train_df['age'].mean()
age_std = train_df['age'].std()

In [5]:
class PaddyMultimodalDataset(Dataset):
    def __init__(self, df, img_dir, age_mean, age_std, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        
        self.age_mean = age_mean
        self.age_std = age_std

        unique_varieties = self.df['variety'].unique()
        self.variety2idx = {v: i for i, v in enumerate(unique_varieties)}
        
        unique_label = self.df['label'].unique()
        self.label2idx = {l: i for i, l in enumerate(unique_label)}
        self.idx2label = {i: l for i, l in enumerate(unique_label)}

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['label'], row['image_id'])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        age = (row['age'] - self.age_mean) / self.age_std
        age = torch.tensor(age, dtype=torch.float32)

        variety_id = self.variety2idx[row['variety']]
        variety_id = torch.tensor(variety_id, dtype=torch.long)
        
        label = self.label2idx[row['label']]
        label = torch.tensor(label, dtype=torch.long)
        
        return image, age, variety_id, label

In [7]:
train_transform = transforms.Compose([
    # data augmentations
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),

    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], # normalize images to ImageNet mean and std
        std=[0.229, 0.224, 0.225]
    )
])

In [9]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, ages, varieties, labels in train_loader:
        # move data to the right device
        images = images.to(device)
        labels = labels.to(device)
        ages = ages.to(device)
        varieties = varieties.to(device)

        
        #forward pass
        outputs = model(images, ages, varieties)
        loss = criterion(outputs, labels)
        
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, preds =torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    avg_loss = running_loss / total
    avg_acc = correct / total
    
    return avg_loss, avg_acc


In [10]:
def validate_one_epoch(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, ages, varieties, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            ages = ages.to(device)
            varieties = varieties.to(device)

            outputs = model(images, ages, varieties)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = running_loss / total
    avg_acc = correct / total
    return avg_loss, avg_acc


In [11]:
class PaddyMultimodalModel(nn.Module):
    def __init__(self, num_varieties, num_classes, embedding_dim=8):
        super().__init__()
        
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        modules = list(model.children())[:-1]
        self.backbone = nn.Sequential(*modules)

        for param in model.parameters():
            param.requires_grad = False

        self.variety_embedder = nn.Embedding(num_embeddings=num_varieties, embedding_dim=embedding_dim)
        
        in_features = 2048 + 1 + embedding_dim
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        
    def forward(self, image, age, variety):
        x_image = self.backbone(image) # (Batch, 2048, 1, 1)
        x_image = torch.flatten(x_image, 1) # (Batch, 2048)

        age = age.view(-1, 1).to(x_image.dtype)  # (Batch, 1)
        x_variety = self.variety_embedder(variety) # (Batch, embedding_dim)

        x = torch.cat([x_image, age, x_variety], dim=1) # (Batch, 2048 + 1 + embedding_dim)
        out = self.classifier(x)
        return out


In [13]:
def train_multimodal_model(
    df_train, df_val, img_dir, 
    train_transform, val_transform, 
    embedding_dim=8, batch_size=32, lr=1e-4,
    weight_decay=1e-4, num_epochs=10, device=None, num_workers=None
):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    age_mean = df_train['age'].mean()
    age_std = df_train['age'].std()
    
    train_dataset = PaddyMultimodalDataset(df_train, img_dir, age_mean, age_std, transform=train_transform)
    val_dataset   = PaddyMultimodalDataset(df_val, img_dir, age_mean, age_std, transform=val_transform)
    
    if num_workers is None:
        num_workers = multiprocessing.cpu_count()
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    num_varieties = df_train['variety'].nunique()
    num_classes   = df_train['label'].nunique()
    
    model = PaddyMultimodalModel(num_varieties=num_varieties, num_classes=num_classes, embedding_dim=embedding_dim)
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam([
        {'params': model.variety_embedder.parameters(), 'lr': lr},
        {'params': model.classifier.parameters(), 'lr': lr}
    ], weight_decay=weight_decay)
    
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss  : {val_loss:.4f}   | Val Acc  : {val_acc:.4f}")

    return model

In [None]:
model = train_multimodal_model(train_df, val_df, img_dir, train_transform, val_transform, lr=1e-4, num_epochs=10)

Epoch 1/10
  Train Loss: 1.9256 | Train Acc: 0.3511
  Val Loss  : 2.3907   | Val Acc  : 0.1888
Epoch 2/10
  Train Loss: 1.6583 | Train Acc: 0.4389
  Val Loss  : 2.5347   | Val Acc  : 0.1988
