In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset, random_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import cv2
import torchvision
from torchvision import transforms as T
from torchvision import datasets, models
from torchvision.datasets import ImageFolder

from torchmetrics.classification import MulticlassJaccardIndex

import sklearn
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score

import segmentation_models_pytorch as smp

import wandb

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Seed every random number generator used here so reruns are reproducible.
LUCKY_SEED = 42

# NumPy's legacy global generator.
np.random.seed(LUCKY_SEED)

# PyTorch generators; the CUDA ones are only seeded when a GPU is present.
torch.manual_seed(LUCKY_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(LUCKY_SEED)

# Ask cuDNN to stick to deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
# Build the U-Net (EfficientNet-B1 encoder, 23 output classes) that the
# checkpoint below is loaded into.
# encoder_weights=None: load_state_dict (strict by default) overwrites every
# parameter and buffer, so fetching the ImageNet weights first would only cost
# a network download.
model = smp.Unet(
    encoder_name="efficientnet-b1",
    encoder_weights=None,
    classes=23,
)

# NOTE(security): torch.load unpickles the file - only load checkpoints you trust.
# map_location=device lets a checkpoint saved on a GPU be restored on a
# CPU-only machine (and vice versa) instead of failing during deserialisation.
state_dict = torch.load('weights.pt', map_location=device)  # weights from the trained model
model.load_state_dict(state_dict)

# Move to the target device and switch to inference mode (affects layers such
# as dropout / batch norm). Assigning the result keeps the notebook from
# printing the entire module tree as cell output.
model = model.to(device).eval()

In [None]:
# Inference-time preprocessing, applied in this order:
#   1. resize to 352 x 512 (height x width),
#   2. normalise each channel with the given mean / std,
#   3. convert the array to a torch tensor.
_test_steps = [
    A.Resize(height=352, width=512),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2(),
]
transforms_test = A.Compose(_test_steps)

In [None]:
class CustomTestDataset(Dataset):
    """Test-time dataset yielding ``(model_input, original_image)`` pairs.

    Each image is looked up as ``<img_path>/<name>.jpg`` where ``<name>`` is
    the value stored in the first column of ``csv``.

    Args:
        img_path: Directory that contains the ``.jpg`` images.
        csv: Table supporting ``len()`` and ``.iloc[row, col]`` (e.g. a pandas
            DataFrame) whose first column holds image names without extension.
        transform: Optional callable invoked as ``transform(image=img)`` and
            returning a mapping with an ``"image"`` entry. When ``None`` the
            raw image is converted to a tensor as is.
    """

    def __init__(self, img_path, csv, transform = None):
        self.img_path = img_path
        self.csv = csv
        self.transform = transform

    def __len__(self):
        """Return the number of rows in ``csv`` (one sample per row)."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(img, img_orig)`` where ``img`` is a float tensor (the
            transformed image, or the raw image in CxHxW layout when no
            transform is set) and ``img_orig`` is the RGB image array as read
            from disk, before any transform, kept for plotting.

        Raises:
            FileNotFoundError: If the image file does not exist.
            OSError: If the file exists but OpenCV cannot decode it.
        """
        img_path = os.path.join(self.img_path, self.csv.iloc[idx, 0]) + '.jpg'

        # cv2.imread signals failure by returning None rather than raising, so
        # check explicitly to give a clear error instead of a cvtColor crash.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f"Could not decode image file: {img_path}")

        # OpenCV loads images as BGR; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img_orig = img # orig image w/o normalization to display later on graphs

        if self.transform is not None:
            aug = self.transform(image = img)
            img = aug["image"]

        # Without a transform (or with one that returns an array) there is no
        # tensor yet: convert HxWxC -> CxHxW so .float() below always works.
        if not torch.is_tensor(img):
            img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)

        img = img.float()

        return img, img_orig

In [None]:
# Placeholders: point these at the image folder and at the .csv that lists
# the image names before running this cell.
img_path = "specify image folder path here"
csv = pd.read_csv("specify .csv with image names path here")

# Wrap the images in a dataset that applies the test-time preprocessing.
test_dataset = CustomTestDataset(
    img_path=img_path,
    csv=csv,
    transform=transforms_test,
)

# Iterate in file order (no shuffling), eight samples per batch.
testloader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

In [None]:
# Display the prediction for a single test image.
SAMPLE_IDX = 0  # index into test_dataset of the image to visualise
img, img_orig = test_dataset[SAMPLE_IDX]

# The model expects a batch: add a leading batch dimension (CxHxW -> 1xCxHxW)
# instead of hard-coding the input resolution.
img_batch = img.unsqueeze(0).to(device)

# Inference only - no need to record operations for autograd.
with torch.no_grad():
    pred = model(img_batch)

# Per-pixel class index: the channel with the highest score.
mask = torch.argmax(pred, dim = 1)

# The test dataset yields no ground-truth masks, so only the input image and
# the predicted mask can be shown here (an mIoU score would need a target).
figure, (ax1, ax2) = plt.subplots(1, 2, figsize = (20, 10))

# img_orig is the HxWxC RGB array returned by the dataset; imshow takes it as is.
ax1.imshow(img_orig)
ax1.set_title("original image")

# squeeze removes the batch dimension
ax2.imshow(mask.cpu().squeeze())
ax2.set_title("pred mask");
