In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools
import os
import time
from datetime import datetime
import numpy as np
import torchvision.utils as vutils
import utils
import glob
import random
import torchvision
from torch.utils.data import Dataset
from PIL import Image

In [None]:
class ImageDataset(Dataset):
    """Two-domain image-folder dataset (CycleGAN-style A/B pairs).

    Files are read from ``<root_dir>/<mode>A`` and ``<root_dir>/<mode>B``.
    Items are dicts keyed ``'trainA'``/``'trainB'`` when ``mode == 'train'``
    and ``'testA'``/``'testB'`` for any other mode.

    Args:
        root_dir: Directory holding the ``<mode>A`` and ``<mode>B`` folders.
        transform: Sequence of torchvision transforms applied in order to each
            PIL image. ``None`` (default) applies no transform, so raw PIL
            images are returned.
        unaligned: If True, the B image is picked with ``random.randint``
            instead of by index, so A/B pairs are not fixed.
        mode: Split name used in the folder names ('train', 'test', ...).
        n_files_A: Number of A paths drawn with ``np.random.choice``
            (default 8000, as before). ``None`` keeps every file exactly once.
        n_files_B: Same for domain B (default 12000).

    Raises:
        ValueError: If a domain folder contains no ``*.*`` files.
    """

    def __init__(self, root_dir, transform=None, unaligned=False, mode='train',
                 n_files_A=8000, n_files_B=12000):
        # Compose(None) constructs fine but fails on first use, so map the
        # documented default to an empty (identity) pipeline.
        self.transform = torchvision.transforms.Compose(
            list(transform) if transform is not None else [])
        self.unaligned = unaligned
        self.train = (mode == 'train')
        self.files_A = self._collect_files(os.path.join(root_dir, '%sA' % mode), n_files_A)
        self.files_B = self._collect_files(os.path.join(root_dir, '%sB' % mode), n_files_B)

    @staticmethod
    def _collect_files(folder, n_files):
        """Return a sorted list of paths from ``folder``, optionally resampled."""
        files = glob.glob(folder + '/*.*')
        if not files:
            raise ValueError('No files matching *.* found in %r' % folder)
        if n_files is not None:
            # np.random.choice defaults to replace=True, so the draw can repeat
            # paths (original behaviour, kept so seeded runs match).
            # NOTE(review): confirm whether a subset without replacement was intended.
            files = np.random.choice(files, n_files)
        return sorted(files)

    def _load(self, path):
        """Open ``path``, read its pixels, close the file and apply the transform."""
        with Image.open(path) as img:
            # load() reads the data now, so the image stays usable after the
            # context manager closes the underlying file handle.
            img.load()
            return self.transform(img)

    def __getitem__(self, index):
        item_A = self._load(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            item_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            item_B = self._load(self.files_B[index % len(self.files_B)])
        prefix = 'train' if self.train else 'test'
        return {prefix + 'A': item_A, prefix + 'B': item_B}

    def __len__(self):
        # The shorter domain wraps around via ``index % len(...)`` in __getitem__.
        return max(len(self.files_A), len(self.files_B))

In [None]:
import argparse
import sys
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from models import *

# DataLoader is used below, but only Dataset is imported explicitly; import it
# here so this cell does not depend on what `from models import *` exports.
from torch.utils.data import DataLoader

FLAGS = None
# Fall back to CPU when no GPU is visible (`"cuda" if True` always chose CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_size = 128
data_dir = 'dataset_cyclegan/'
dataset = 'mnist_noise_8_003'
channels = 1
num_blocks = 9
lr = 0.0002
print('Loading data...\n')

# Resize up by 12%, random-crop back to img_size, then Normalize applies
# (x - 0.5) / 0.5 per channel. `(0.5,)` is a 1-tuple; `(0.5)` is just a float.
transform = [transforms.Resize(int(img_size*1.12), Image.BICUBIC),
             transforms.RandomCrop((img_size, img_size)),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize((0.5,), (0.5,))]
# NOTE(review): the test loader reuses the random crop/flip augmentation from
# training; confirm that is intended for evaluation.
dataloader = DataLoader(ImageDataset(os.path.join(data_dir, dataset),
                                     transform=transform, unaligned=True, mode='train'),
                        batch_size=2, shuffle=True, num_workers=0)
test_dataloader = DataLoader(ImageDataset(os.path.join(data_dir, dataset),
                                          transform=transform, unaligned=True, mode='test'),
                             batch_size=20, shuffle=True, num_workers=0)



In [None]:

# Build the project's `Model` wrapper (from `models`) for the 'cyclegan' setup,
# handing it the device, both data loaders and the image/network config above.
# `load_from('weights/')` presumably restores saved weights from that folder —
# confirm against models.Model.
# NOTE(review): `Model.eval(batch_size=...)` is a project method, not
# nn.Module.eval(); confirm whether it also puts the generators in eval mode.
model = Model('cyclegan', device, dataloader, test_dataloader,channels, img_size, num_blocks)
model.load_from('weights/')
model.eval(batch_size=20)


In [None]:
# Translate one batch of domain-A test images with generator_AB.
# (The former single-image path used an undefined `noised` variable, raised
# NameError on a fresh kernel, and was overwritten immediately; it is removed.)
images = next(iter(test_dataloader))['testA'].to(device)
# Inference only: skip building the autograd graph.
with torch.no_grad():
    predictions = model.generator_AB(images)
# Invert the (x - 0.5) / 0.5 input normalisation so outputs land in [0, 1]
# (assumes the generator emits values in [-1, 1] — TODO confirm in models).
predictions = (predictions + 1)/2

In [None]:

from torchvision.utils import make_grid
import matplotlib.pyplot as plt
figsize = (10, 5)
# Tile the first 12 translated images into a 4-column grid with white padding.
img_grid = make_grid(predictions[:12], nrow=4, padding=10, pad_value=1)
# make_grid yields (C, H, W); matplotlib expects (H, W, C).
grid_hwc = img_grid.detach().cpu().numpy().transpose(1, 2, 0)
plt.figure(figsize=figsize)
plt.imshow(grid_hwc, interpolation='nearest', cmap='gray')
plt.axis('off')
plt.show()

In [None]:
figsize = (10, 6)
plt.figure(figsize=figsize)
# The data transform normalises inputs to [-1, 1]; map them back to [0, 1]
# (the same (x + 1) / 2 applied to `predictions`) so imshow does not clip
# every negative value to black.
img_grid = make_grid((images[:12] + 1) / 2, nrow=4, padding=10, pad_value=1)
plt.imshow(np.transpose(img_grid.detach().cpu().numpy(), (1, 2, 0)), interpolation='nearest', cmap='gray')
plt.axis('off')
plt.show()

In [None]:
imgs_drawn = []


def drawing_figure(image, size=800, threshold=40, min_area=25000):
    """Find large contours in a grayscale image and rasterise their outlines.

    The image is resized to ``size`` x ``size``, binarised with
    ``cv2.threshold(..., threshold, 255, cv2.THRESH_BINARY)``, and contours
    whose ``cv2.contourArea`` is below ``min_area`` are discarded.

    Args:
        image: Grayscale image; the caller passes a 2-D uint8 array
            (generator output scaled to 0..255).
        size: Side length the image is resized to before thresholding.
        threshold: Binarisation threshold passed to ``cv2.threshold``.
        min_area: Minimum contour area, in resized pixels, to keep.

    Returns:
        (draw_img, contours): a float array of the resized image's shape that
        is 255 on kept contour points and 0 elsewhere, and the kept contours.
    """
    # NOTE(review): cv2 (OpenCV) is used here but is not imported anywhere in
    # this notebook's visible imports; add `import cv2` to the imports cell.
    gray = cv2.resize(image, (size, size))
    _, threshed = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and
    # OpenCV 4 (2-tuple) return forms of findContours.
    contours = cv2.findContours(threshed, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
    kept = [c.copy() for c in contours if cv2.contourArea(c) >= min_area]
    draw_img = np.zeros(gray.shape)
    for contour in kept:
        # Each contour is an (N, 1, 2) array of (x, y) points; index rows by y.
        points = contour[:, 0, :]
        draw_img[points[:, 1], points[:, 0]] = 255
    return draw_img, kept
for image, prediction in zip(images, predictions):
    # Undo the input normalisation ((x - 0.5) / 0.5) so the background is in
    # [0, 1] rather than [-1, 1], which imshow would clip.
    noise_dotted_for_drawing = cv2.resize(((image[0] + 1) / 2).cpu().numpy(), dsize=(800, 800))
    # Clip before casting: float -> uint8 conversion of out-of-range values is
    # not saturating and can wrap to unrelated intensities.
    t = np.clip(prediction[0].detach().cpu().numpy() * 255, 0, 255).astype('uint8')
    contoured_img, hh = drawing_figure(t)
    # Single-channel float image in [0, 1]: draw the outlines at full intensity.
    im2 = cv2.drawContours(noise_dotted_for_drawing, hh, -1, 1.0, 8)
    imgs_drawn.append(torch.tensor(im2).unsqueeze(0))

figsize = (20, 10)
fig = plt.figure(figsize=figsize)
img_grid1 = make_grid(imgs_drawn, nrow=10, padding=0)
ax = plt.imshow(np.transpose(img_grid1.detach().cpu().numpy(), (1, 2, 0)), cmap='gray')
plt.axis('off')
plt.show()
