In [24]:
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image
import os
import copy
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large as deeplab

In [7]:
class myDS(Dataset):
    """Dataset over every ``.png`` file found recursively under a directory.

    Files whose name contains ``"checkpoint"`` are ignored. Each item is a
    pair ``(raw, tensor)``: the RGB image as a NumPy array, and the same image
    after the resize / to-tensor / normalize pipeline stored in
    ``self.transforms``.
    """

    def __init__(self, fd):
        """Collect the image paths and build the preprocessing pipeline.

        Args:
            fd: Root directory that is walked recursively for ``.png`` files.
        """
        self.path_list = []
        for root, dirs, files in os.walk(fd):
            for file in files:
                # Logical `and` (not bitwise `&`) so the intent is explicit
                # and the second test is short-circuited.
                if file.endswith('.png') and "checkpoint" not in file:
                    self.path_list.append(os.path.join(root, file))
        # os.walk yields entries in filesystem order; sorting keeps the
        # idx -> path mapping stable across runs and machines.
        self.path_list.sort()
        self.transforms = transforms.Compose([
                transforms.Resize(520),
                transforms.ToTensor(),  # Rescales to [0.0, 1.0]
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def __len__(self):
        """Return the number of images that were discovered."""
        return len(self.path_list)

    def __getitem__(self, idx):
        """Load image ``idx``.

        Args:
            idx: Position in ``self.path_list``.

        Returns:
            Tuple ``(np.ndarray, transformed)`` where the first element is the
            RGB image as loaded and the second is the output of
            ``self.transforms`` on that same image.
        """
        image_path = self.path_list[idx]
        # The context manager closes the underlying file handle right away;
        # convert() returns a fully loaded copy that remains usable afterwards.
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")

        return np.array(image), self.transforms(image)

In [26]:
# Build the 4-class segmentation network, restore the saved weights and move
# the model to the GPU.
model = deeplab(num_classes=4)
# map_location="cpu" deserializes the tensors onto the CPU regardless of the
# device they were saved from, so loading does not fail when that device is
# absent here; the weights reach the GPU through the .cuda() call below.
state_dict = torch.load("model_99.pth", map_location="cpu")
model.load_state_dict(state_dict)
model = model.cuda()

In [27]:
# Dataset over the training images, served in shuffled batches of 32.
ds = myDS("/home/jupyter/ai_font/data/exp0717/train_whole")
dl = DataLoader(dataset=ds, batch_size=32, shuffle=True)
# Resize that the inference cell applies to the model output before argmax.
pred_transforms = transforms.Resize(size=128)

In [28]:
# Preview the segmentation on the first batch only: for every image and every
# class id in {1, 2, 3}, write a copy of the image in which all pixels NOT
# predicted as that class are painted white (255).
model.eval()
# Inference only: without no_grad() the forward pass records an autograd graph
# that is never used and keeps activations alive on the GPU.
with torch.no_grad():
    for j, data in enumerate(dl):
        imgs, x = data
        x = x.cuda()
        pred = model(x)
        # Per-pixel class id after resizing the logits with pred_transforms.
        mask = torch.argmax(pred_transforms(pred['out']), dim=1)
        # One device->host transfer per batch instead of one per (image, class).
        mask_np = mask.cpu().numpy()
        for k, img in enumerate(imgs):
            source = img.cpu().numpy()
            for cls_id in [1, 2, 3]:
                # Work on a copy so `source` stays intact for the next class.
                npimg = source.copy()
                npimg[mask_np[k] != cls_id] = 255
                Image.fromarray(npimg).save(f"test_{j}_{k}_{cls_id}.png")
        # Only the first batch is exported.
        break

In [22]:
# Ad-hoc inspection of `npimg`, the array assigned inside the inference loop.
# NOTE(review): depends on whatever the kernel last stored in `npimg` and
# prints the full array repr — looks like leftover debugging; confirm it is
# still needed.
npimg

array([[[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       ...,

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]],

       [[255, 255, 255],
        [255, 255, 255],
        [255, 255, 255],
        ...,
        [255, 255, 255],
        [255, 255, 255],
        [255, 255, 255]]