In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torchvision import transforms
from torchvision.models import efficientnet_b2
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset,DataLoader
from PIL import Image
from torchvision import transforms
from transformers import BertTokenizer
import os
from torchvision.models import resnet50,resnet18 
from transformers import AutoModel, AutoTokenizer

In [None]:
# Location of the competition data and the names of the two CSV files in it.
dir_path = "/kaggle/input/multi-label-classification-competition-2025/COMP5329S1A2Dataset/"
train_file_name = "train.csv"
test_file_name = "test.csv"

train_csv_path = os.path.join(dir_path, train_file_name)
test_csv_path = os.path.join(dir_path, test_file_name)

# Only the leading columns are read: the first three for training and the
# first two for testing. Later cells access 'ImageID', 'Labels' and 'Caption'
# on the training frame.
train_csv = pd.read_csv(train_csv_path, usecols=[0, 1, 2])
test_csv = pd.read_csv(test_csv_path, usecols=[0, 1])

In [None]:
# Parse every space-separated label string into a list of integer label ids.
all_label_lists = train_csv['Labels'].apply(
    lambda label_str: [int(token) for token in label_str.split()]
)

# Flatten the per-row lists into one list of all label ids seen in training.
flat_labels = []
for row_labels in all_label_lists:
    flat_labels.extend(row_labels)

# Label ids are used directly as indices, so the one-hot width is max id + 1.
num_classes = 1 + max(flat_labels)
def multilable_to_onehot(label_str, num_classes):
    """Convert a space-separated string of label ids into a multi-hot vector.

    Args:
        label_str: Whitespace-separated integer label ids, e.g. ``"1 3"``.
        num_classes: Length of the returned vector.

    Returns:
        A float32 NumPy array of length ``num_classes`` holding 1 at every
        listed label index and 0 elsewhere.
    """
    onehot = np.zeros(num_classes, dtype=np.float32)
    for token in label_str.split():
        onehot[int(token)] = 1
    return onehot
# Attach the multi-hot target vector for every row; later cells read it from
# the 'onehot' column.
train_csv["onehot"] = train_csv['Labels'].apply(lambda x: multilable_to_onehot(x, num_classes))
# Preview only the first rows instead of rendering the whole frame.
train_csv.head()

### Pre-processing

In [None]:
print("Number of Training data:", len(train_csv))

# Frequency of every distinct label combination (the raw 'Labels' string).
class_counts = train_csv['Labels'].value_counts()
print(class_counts)

# Fold every combination that covers less than 1% of the rows into one bucket
# so the pie chart stays readable.
one_percent = 0.01 * class_counts.sum()
is_small = class_counts < one_percent
small_classes = class_counts[is_small]
other_count = small_classes.sum()
class_counts = class_counts[~is_small]
class_counts['Classes < 1%'] = other_count

# Pie chart of the label-combination shares computed above.
colors = sns.color_palette("Set2", n_colors=len(class_counts))

fig, ax = plt.subplots(figsize=(8, 8))
ax.pie(
    class_counts.values,
    labels=class_counts.index,
    autopct='%1.1f%%',
    startangle=140,
    colors=colors,
    wedgeprops={'edgecolor': 'white'},
)
ax.set_title("Class Distribution")
# Equal aspect ratio so the pie is drawn as a circle.
ax.axis('equal')
fig.tight_layout()
plt.show()


# Bar chart of how often each individual label index is set.
# Summing the per-row multi-hot vectors gives one count per label index.
label_counts = train_csv['onehot'].sum(axis=0)
label_positions = range(len(label_counts))

fig, ax = plt.subplots(figsize=(12, 8))
bars = ax.bar(label_positions, label_counts)

# Write the exact count just above each bar.
for position, count in enumerate(label_counts):
    ax.text(position, count + 0.5, str(int(count)), ha='center', va='bottom')

ax.set_xlabel("Label Index")
ax.set_ylabel("Count")
ax.set_title("Label Distribution in Multi-label Dataset")
ax.set_xticks(label_positions)
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

In [None]:
# We have 30k images, which is big enough, so only light augmentation is used.
# The image branch is an ImageNet-pretrained ResNet, so inputs are normalized
# with the ImageNet channel statistics that the pretrained weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/rotation/colour jitter, then tensor + normalize.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Validation pipeline: deterministic resize only, with the same normalization.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
class Ass2Dataset(Dataset):
    """Image + caption dataset backed by a DataFrame.

    Each row must provide 'ImageID' (file name inside ``image_dir``),
    'Caption' (text passed to the tokenizer) and 'onehot' (target vector).

    Args:
        df: DataFrame holding the rows described above.
        image_dir: Directory that contains the image files.
        tokenizer: Callable tokenizer invoked with ``padding='max_length'``,
            ``truncation=True``, ``max_length`` and ``return_tensors="pt"``.
        transform: Optional callable applied to the loaded RGB image.
        max_length: Token length captions are padded/truncated to.
    """

    def __init__(self, df, image_dir, tokenizer, transform=None, max_length=128):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def _load_sample(self, idx):
        """Build the sample dict for row ``idx``.

        Returns:
            The sample dict, or ``None`` when the image cannot be opened or
            the caption cannot be tokenized (the failure is printed).
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.image_dir, row['ImageID'])
        try:
            # The context manager closes the file handle once convert() has
            # produced an independent in-memory RGB copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            print(f"No Such file idx={idx}, error={e}")
            return None

        if self.transform:
            image = self.transform(image)

        # Process the text.
        text = row['Caption']
        try:
            encoded = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
        except Exception as e:
            print(f"No text idx={idx}, error={e}")
            return None

        # return_tensors="pt" adds a leading batch dimension of size 1; drop it.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        labels = torch.tensor(row['onehot']).float()

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': labels
        }

    def __getitem__(self, idx):
        """Return the sample at ``idx``, falling back to following rows on failure.

        When a row cannot be loaded the next row (wrapping around) is tried,
        as before. Every row is attempted at most once so a dataset where
        nothing can be loaded fails fast instead of recursing without bound.

        Raises:
            RuntimeError: If no row could be loaded after trying all of them.
        """
        candidate = idx
        # max(..., 1) keeps the first lookup for an empty dataset so that the
        # usual IndexError from the DataFrame is still raised.
        for _ in range(max(len(self), 1)):
            sample = self._load_sample(candidate)
            if sample is not None:
                return sample
            candidate = (candidate + 1) % len(self)
        raise RuntimeError(
            f"Could not load any sample starting from idx={idx}; "
            f"all {len(self)} rows failed"
        )

In [None]:
import torch.nn.functional as F
# Inverse-frequency weight per label index (epsilon guards against division
# by zero), rescaled so that the weights sum to 1.
inverse_label_freq = 1.0 / (label_counts + 1e-6)
alpha = inverse_label_freq / inverse_label_freq.sum()
class BCEFocalLoss(nn.Module):
    """Blend of BCE-with-logits and its focal-modulated variant.

    The per-element BCE loss and the focal loss ``bce * (1 - pt) ** gamma``
    are each (optionally) scaled by per-class ``alpha``, averaged over all
    elements, and combined as
    ``bce_weight * bce + (1 - bce_weight) * focal``.

    Args:
        alpha: Optional per-class weights of shape ``(num_classes,)``. May be
            a tensor, NumPy array or sequence; non-tensors are converted to a
            float32 tensor. ``None`` disables class weighting.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        bce_weight: Mixing factor between the BCE and focal terms.
    """

    def __init__(self, alpha=None, gamma=2.0, bce_weight=0.5):
        super(BCEFocalLoss, self).__init__()
        # Accept array-likes (e.g. the NumPy weights computed from label
        # counts) so that forward() can always call .to(device) on alpha.
        if alpha is not None and not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha  # Tensor of shape (num_classes,) or None
        self.gamma = gamma
        self.bce_weight = bce_weight
        self.bce = nn.BCEWithLogitsLoss(reduction='none')  # no reduction, we reduce manually

    def forward(self, inputs, targets):
        """Compute the combined loss.

        Args:
            inputs: Raw logits.
            targets: Targets with the same shape as ``inputs``; entries equal
                to 1 are treated as positives when computing ``pt``.

        Returns:
            Scalar loss tensor.
        """
        bce_loss = self.bce(inputs, targets)

        # pt is the probability the model assigns to the true outcome.
        probas = torch.sigmoid(inputs)
        pt = torch.where(targets == 1, probas, 1 - probas)

        # Down-weight elements that are already classified confidently.
        focal_loss = bce_loss * ((1 - pt) ** self.gamma)

        if self.alpha is not None:
            # alpha is a plain attribute, so it does not follow module.to();
            # move it to the logits' device on every call.
            alpha = self.alpha.to(inputs.device)
            bce_loss = bce_loss * alpha
            focal_loss = focal_loss * alpha

        loss = self.bce_weight * bce_loss.mean() + (1 - self.bce_weight) * focal_loss.mean()
        return loss


## Model

In [None]:
class MultiModalClassifier(nn.Module):
    """Late-fusion classifier over a ResNet18 image branch and a BERT text branch.

    Image features (ResNet18 with its final layer replaced) and text features
    (the first-token hidden state of the BERT encoder passed through a linear
    layer) are concatenated and fed to a small MLP that outputs one logit per
    label.

    Args:
        num_labels: Number of output logits.
        resnet_out: Width of the image feature vector.
        bert_out: Width of the projected text feature vector.
        dropout: Dropout probability inside the fusion MLP.
        bert_name: Checkpoint name passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, num_labels=20, resnet_out=256, bert_out=256, dropout=0.3,
                 bert_name="google/bert_uncased_L-4_H-256_A-4"):
        super(MultiModalClassifier, self).__init__()

        # ResNet18 backbone; its classification head is replaced by a
        # projection to resnet_out features.
        self.resnet18 = resnet18(pretrained=True)
        self.resnet18.fc = nn.Linear(self.resnet18.fc.in_features, resnet_out)

        # Small BERT released by Google.
        self.bert = AutoModel.from_pretrained(bert_name)
        # The projection input width must be the encoder's hidden size, not
        # bert_out; otherwise any bert_out other than the hidden size fails
        # with a shape mismatch in forward().
        self.text_fc = nn.Linear(self.bert.config.hidden_size, bert_out)

        self.classifier = nn.Sequential(
            nn.Linear(resnet_out + bert_out, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_labels)
        )

    def forward(self, image, input_ids, attention_mask):
        """Return one logit per label for each sample in the batch.

        Args:
            image: Image batch fed to the ResNet18 branch.
            input_ids: Token ids fed to the BERT branch.
            attention_mask: Attention mask matching ``input_ids``.
        """
        image_features = self.resnet18(image)
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state at sequence position 0 serves as the sentence representation.
        text_features = bert_output.last_hidden_state[:, 0, :]
        text_features = self.text_fc(text_features)
        fused = torch.cat((image_features, text_features), dim=1)
        return self.classifier(fused)


In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(image, input_ids, attention_mask)``.
        dataloader: Iterable of dict batches with keys 'image', 'input_ids',
            'attention_mask' and 'label'.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device every batch entry is moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0
    for raw_batch in tqdm(dataloader):
        # Move every entry of the batch dict to the target device.
        tensors = {}
        for name, value in raw_batch.items():
            tensors[name] = value.to(device)

        logits = model(
            tensors['image'],
            tensors['input_ids'],
            tensors['attention_mask'],
        )
        batch_loss = criterion(logits, tensors['label'])

        # Standard step: clear old gradients, backpropagate, update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    num_batches = len(dataloader)
    return running_loss / num_batches


def evaluate(model, dataloader, device, threshold=0.5):
    """Score ``model`` on ``dataloader`` with micro- and macro-averaged F1.

    Args:
        model: Module called as ``model(image, input_ids, attention_mask)``.
        dataloader: Iterable of dict batches with keys 'image', 'input_ids',
            'attention_mask' and 'label'.
        device: Device every batch entry is moved to.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value.

    Returns:
        Tuple ``(micro_f1, macro_f1)``.
    """
    model.eval()
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for raw_batch in tqdm(dataloader):
            tensors = {}
            for name, value in raw_batch.items():
                tensors[name] = value.to(device)

            logits = model(
                tensors['image'],
                tensors['input_ids'],
                tensors['attention_mask'],
            )
            # Boolean prediction matrix for this batch.
            predicted = torch.sigmoid(logits) > threshold

            pred_chunks.append(predicted.cpu().numpy())
            target_chunks.append(tensors['label'].cpu().numpy())

    # Stitch the per-batch arrays into full prediction/target matrices.
    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(target_chunks)

    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    return micro_f1, macro_f1


In [None]:
from torchvision.models import resnet50
def get_model_size(model):
    """Print and return the parameter count and estimated size of ``model``.

    The size estimate assumes 4 bytes (float32) per parameter.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        Tuple ``(total_params, model_size_mb)``.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    bytes_per_param = 4  # float32 = 4 bytes
    model_size_mb = total_params * bytes_per_param / (1024 ** 2)

    print(f"✅ 模型总参数量: {total_params:,}")
    print(f"✅ 可训练参数量: {trainable_params:,}")
    print(f"✅ 模型大小（内存占用，float32）: {model_size_mb:.2f} MB")

    return total_params, model_size_mb
def _inverse_frequency_alpha(label_freq):
    """Turn per-class positive counts into inverse-frequency weights summing to 1.

    Classes with a zero count receive the weight of the rarest observed class.
    A plain ``1 / (count + eps)`` would give them a huge weight that, after
    normalization, pushes the weight of every real class to almost zero.
    """
    label_freq = label_freq.float()
    present = label_freq > 0
    if not present.any():
        # No positives at all: fall back to uniform weights.
        return torch.full_like(label_freq, 1.0 / max(label_freq.numel(), 1))
    inv_freq = torch.zeros_like(label_freq)
    inv_freq[present] = 1.0 / label_freq[present]
    inv_freq[~present] = inv_freq[present].max()
    return inv_freq / inv_freq.sum()


def main():
    """Train the multimodal classifier for 10 epochs and print per-epoch metrics.

    Splits ``train_csv`` 80/20 into training and validation frames, builds the
    datasets/loaders, model, class-weighted loss and AdamW optimizer, then
    alternates one training epoch with one validation pass.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-4_H-256_A-4")

    # Reproducible 80/20 split; validation rows are whatever was not sampled.
    train_df = train_csv.sample(frac=0.8, random_state=42)
    val_df = train_csv.drop(train_df.index)

    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)
    train_dataset = Ass2Dataset(train_df, os.path.join(dir_path,"data"), tokenizer, train_transform)
    val_dataset = Ass2Dataset(val_df, os.path.join(dir_path,"data"), tokenizer, val_transform)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Size the output layer from the data-derived class count so the logits
    # always match the width of the one-hot targets.
    model = MultiModalClassifier(num_labels=num_classes).to(device)
    get_model_size(model)

    # Initialise alpha from the label frequencies. np.stack builds the
    # (num_rows, num_classes) matrix in one step instead of converting a
    # Python list of arrays element by element.
    onehot_matrix = torch.from_numpy(np.stack(train_csv["onehot"].to_numpy()))
    label_freq = onehot_matrix.sum(dim=0)
    alpha = _inverse_frequency_alpha(label_freq)

    criterion = BCEFocalLoss(alpha=alpha, gamma=2, bce_weight=0.5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    for epoch in range(10):
        train_loss = train_one_epoch(model,train_loader,criterion,optimizer,device=device)
        micro_f1, macro_f1 = evaluate(model,val_loader,device=device)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Micro-F1: {micro_f1:.4f} | Macro-F1: {macro_f1:.4f}")
# Start training only when this code runs as the top-level module
# (__name__ == '__main__').
if __name__ == '__main__':
    main()
