In [1]:
import torch
import torchvision.transforms as transforms
import os
from PIL import Image
# from resnet18 import init_model
import torchvision
from kan import KAN
from torch import nn
import pandas as pd

In [9]:
MODEL_WEIGHTS = "baseline_resnet50.pt"
TEST_IMAGES_DIR = "./data/train/cigs/"
SUBMISSION_PATH = "./data/submission2.csv"

In [10]:
class AdaptiveConcatPool2d(nn.Module):
    def __init__(self):
        super(AdaptiveConcatPool2d, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        return torch.cat([self.avgpool(x), self.maxpool(x)], dim=1)

In [11]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model = init_model(device, 2)
model = torchvision.models.resnet50(pretrained=False)
model.avgpool = AdaptiveConcatPool2d()
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    KAN([model.fc.in_features * 2, 64, 2]))
model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location=torch.device('cpu')))
model = model.to(device)
model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
img_size = 224
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [13]:
all_image_names = os.listdir(TEST_IMAGES_DIR)
all_preds = []

for image_name in all_image_names:
    img_path = os.path.join(TEST_IMAGES_DIR, image_name)
    image = Image.open(img_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)
        pred = torch.sigmoid(output[:, 1]).item() >= 0.5
        all_preds.append(int(pred))

with open(SUBMISSION_PATH, "w") as f:
    f.write("image_name\tlabel_id\n")
    for name, cl_id in zip(all_image_names, all_preds):
        f.write(f"{name}\t{cl_id}\n")

In [14]:
data = pd.DataFrame(data=all_image_names, columns=["name"])

# data['label'] = y_true
data['pred'] = all_preds

In [15]:
data[data['pred'] == 0]

Unnamed: 0,name,pred
23,good_cigs_000389.jpg,0
84,good_cigs_000385.jpg,0
131,good_cigs_000099.jpg,0
132,good_cigs_000248.jpg,0
144,good_cigs_000302.jpg,0
148,good_cigs_000351.jpg,0
191,good_cigs_000133.jpg,0
204,good_cigs_000285.jpg,0
230,p3.jpg,0
262,good_cigs_000354.jpg,0
