In [8]:
import os
import torch
import pandas as pd
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder

# Show all rows in the output
pd.set_option('display.max_rows', None)
# Optional: Show all columns if needed
pd.set_option('display.max_columns', None)
# Optional: Expand full column width
pd.set_option('display.max_colwidth', None)
# Optional: Do not truncate the output horizontally
pd.set_option('display.expand_frame_repr', False)

# CONFIG
train_img_dir = "../../data_t1/train"
train_csv_path = "../../data_t1/train.csv"
test_img_dir = "../../data_t1"
test_csv_path =  "../../data_t1/test.csv"
model_path = 'resnet50_epoch20.pt'
model_name = "resnet50"

train_df = pd.read_csv(train_csv_path) 
test_df = pd.read_csv(test_csv_path) 

num_labels = len(train_df.label.unique())  # must match training

IMG_SIZE = 224
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#  Your transform (same as during training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

#  Load model
def load_trained_model(model_name, model_path, num_labels):
    if "resnet" in model_name:
        model = models.resnet50()
        model.fc = nn.Linear(model.fc.in_features, num_labels)
    elif "densenet" in model_name:
        model = models.densenet121()
        model.classifier = nn.Linear(model.classifier.in_features, num_labels)
    else:
        raise ValueError("Unsupported model type")
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    return model

model = load_trained_model(model_name, model_path, num_labels)

#  Dataset class
class InferenceDataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        self.data = dataframe
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        filename = self.data.iloc[idx]['filename']
        img_path = os.path.join(self.img_dir, filename)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, filename


# Step 3: Create DataLoader
inference_dataset = InferenceDataset(test_df, test_img_dir, transform)
inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)

# Step 4: Inference
model.eval()
model.to(DEVICE)

results = []
with torch.no_grad():
    for images, filenames in tqdm(inference_loader, desc="🔍 Predicting"):
        images = images.to(DEVICE)
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        top3 = torch.topk(probs, k=3)
        preds = top3.indices[0].cpu().numpy()
        results.append([filenames[0]] + preds.tolist())

# Step 5: Format results into DataFrame
res_df = pd.DataFrame(results, columns=["filename", "label_1", "label_2", "label_3"])
# print(res_df)


🔍 Predicting:   0%|          | 117/113592 [00:04<1:15:14, 25.13it/s]


KeyboardInterrupt: 

In [None]:
res_df.to_csv("submissions_task1_epoch20.csv",index=False)