In [3]:
import torch
import torch.nn as nn
import numpy as np
from unet import UNet
from cnn import CNN
import matplotlib.pyplot as plt
import torchvision.ops as ops
import cv2 as cv
import os
import csv

In [4]:
# Select the compute device used by every later cell.
# Prefer the first CUDA GPU; fall back to the CPU so the notebook still runs
# (more slowly) on a machine without CUDA instead of failing in this cell.
if torch.cuda.is_available():
    torch.cuda.set_device(0)  # make GPU 0 the current device for bare 'cuda' references
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# Build the segmentation U-Net and restore its trained weights from disk.
unet = UNet(num_classes=1, in_channels=3, depth=5, merge_mode='concat')
# map_location remaps the stored tensors onto the active device, so the checkpoint
# loads even if it was saved from a different device than the one in use here.
unet.load_state_dict(
    torch.load("models/unet.pt", map_location=device, weights_only=True)
)
unet.to(device)
unet.eval();  # evaluation mode; trailing ';' suppresses the module repr in the cell output

In [7]:
# Load every file in the test directory as a 224x224 RGB image, scale it to [0, 1]
# and standardise each colour channel.
directory = "test"  # no hard-coded "\\": os.path.join supplies the platform separator

# Per-channel mean / std in R, G, B order (last axis of the image stack).
channel_mean = np.array([0.485, 0.456, 0.406])
channel_std = np.array([0.229, 0.224, 0.225])

images = []
for file in os.listdir(directory):
    path = os.path.join(directory, file)
    img = cv.imread(path)
    if img is None:
        # cv.imread signals an unreadable / non-image file by returning None.
        # Fail loudly rather than skipping, so the image order stays aligned
        # with the directory listing used when the results are written out.
        raise ValueError(f"could not read {path!r} as an image")
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
    img = cv.resize(img, (224, 224))
    images.append(img)

# (N, 224, 224, 3) float64 array in [0, 1]
images = np.stack(images)/255
# Broadcasting over the last axis applies (x - mean) / std to each channel.
images = (images - channel_mean) / channel_std

In [8]:
# Convert the normalised NumPy image stack into an independent float32 tensor
# (copy=True guarantees the tensor never shares memory with `images`).
test_images = torch.from_numpy(images).to(dtype=torch.float32, copy=True)

In [9]:
# Segment each test image with the U-Net, suppress the background, crop to the
# bounding box of the predicted mask and resize the crop to 64x64.
cropped_images = []
with torch.no_grad():  # inference only: do not build an autograd graph per image
    for i in range(test_images.shape[0]):
        # Channels-last (1, H, W, C) -> channels-first (1, C, H, W) for the network.
        logits = unet(test_images[i:i+1].permute(0, 3, 1, 2).to(device))
        # squeeze() drops the size-1 dims so the mask is a 2-D (H, W) map.
        m = torch.sigmoid(logits.squeeze()).cpu()
        # Zero every probability <= 0.1; values above the threshold keep their
        # original (soft) value, so the mask is not binary.
        m = torch.threshold(m, 0.1, 0)
        # Weight the image by the soft mask (broadcast over the channel axis).
        img = test_images[i] * m.unsqueeze(2)
        if bool((m != 0).any()):
            x1, y1, x2, y2 = ops.masks_to_boxes(m.unsqueeze(0)).int().tolist()[0]
            # A mask confined to a single row/column yields x1 == x2 or y1 == y2,
            # which would give an empty slice and make cv.resize fail; keep at
            # least one pixel in each direction.
            # NOTE(review): masks_to_boxes reports max coordinates inclusively, so
            # this exclusive slice drops the last row/column of the box. Left
            # unchanged for non-degenerate boxes -- confirm how the classifier's
            # training crops were produced before widening it.
            x2 = max(x2, x1 + 1)
            y2 = max(y2, y1 + 1)
            img = img[y1:y2, x1:x2]
        # else: nothing was segmented (masks_to_boxes cannot handle an empty
        # mask); keep the full, all-zero frame so outputs stay aligned with inputs.
        img = cv.resize(img.numpy(), (64, 64))
        cropped_images.append(img)
images = np.stack(cropped_images)
# np.save("cropped_test_set.npy", cropped_images)

In [10]:
# Build the classifier and restore its trained weights from disk.
cnn = CNN()
# map_location remaps the stored tensors onto the active device, so the checkpoint
# loads even if it was saved from a different device than the one in use here.
cnn.load_state_dict(
    torch.load("models/cnn.pt", map_location=device, weights_only=True)
)
cnn.to(device)
cnn.eval();  # evaluation mode; trailing ';' suppresses the module repr in the cell output

In [11]:
# Inputs for the classification pass.
predictions = list()  # filled with one predicted label per image
# Cropped image stack as a float32 tensor.
test_set = torch.tensor(images, dtype=torch.float32)
# NOTE(review): assumes this listing returns the same entries, in the same order,
# as the listing used when the images were loaded -- confirm.
files = os.listdir("test")

In [12]:
# Classify every cropped image and write one (file name, predicted label) row per
# image to submission.csv.
if len(files) != test_set.shape[0]:
    # Without this check a mismatch would either raise an opaque IndexError or
    # silently pair file names with the wrong predictions.
    raise ValueError(
        f"{len(files)} file names but {test_set.shape[0]} images to classify"
    )

# Rebuild the list on every run so re-executing this cell never accumulates
# predictions left over from an earlier execution.
predictions = []
with torch.no_grad():  # inference only: do not build an autograd graph per image
    for i in range(test_set.shape[0]):
        t = test_set[i:i+1].to(device)
        # Channels-last (1, H, W, C) -> channels-first (1, C, H, W) for the network.
        logits = cnn(t.permute(0, 3, 1, 2))
        # softmax is monotonic, so the argmax of the raw logits selects the same
        # class; +1 turns the 0-based index into the label written to the file.
        predictions.append(torch.argmax(logits, dim=1).item() + 1)

rows = [[name, label] for name, label in zip(files, predictions)]
# newline='' lets the csv module control line endings itself.
with open("submission.csv", mode='w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerows(rows)