In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import transforms_enhance as T
from removal import RemoveTransform
from coco_dataset import CopyMoveDataset, SplicingDataset
from gen_patches import DresdenDataset
from model_ASPP import create_model, model_load_weights
import numpy as np

In [3]:
# Shared inputs for the two Dresden-based datasets.
MASK_DIR = '/home/jayda960825/Documents/irregular_mask/disocclusion_img_mask/'
DRESDEN_TRAIN_DIR = '/home/jayda960825/Documents/Dresden/dresden/train'
# The two size arguments (256, 256) passed to every training dataset —
# presumably patch height and width; confirm against the dataset classes.
PATCH_SIZE = (256, 256)

# removal dataset: Dresden images run through RemoveTransform with the shared mask bank
rm_train_transform = RemoveTransform(MASK_DIR)
rm_train_dataset = DresdenDataset(DRESDEN_TRAIN_DIR, *PATCH_SIZE, transform=rm_train_transform)

# enhancement dataset: T.Enhance built from the operation list below plus the shared mask bank
man_list = [
    T.Blur(),
    T.MorphOps(),
    T.Noise(),
    T.Quantize(),
    T.AutoContrast(),
    T.Equilize(),
    T.Compress(),
]
en_train_transform = T.Enhance(man_list, MASK_DIR)
en_train_dataset = DresdenDataset(DRESDEN_TRAIN_DIR, *PATCH_SIZE, transform=en_train_transform)

# COCO train2017 annotations and images, shared by the copy-move and splicing datasets
json_path = "/home/jayda960825/coco/annotations/instances_train2017.json"
pic_path = "/home/jayda960825/coco/train2017"

# copy-move dataset
cp_train_dataset = CopyMoveDataset(json_path, pic_path, *PATCH_SIZE)

# splicing dataset
sp_train_dataset = SplicingDataset(json_path, pic_path, *PATCH_SIZE)

loading annotations into memory...
Done (t=11.56s)
creating index...
index created!
loading annotations into memory...
Done (t=12.65s)
creating index...
index created!


In [4]:
from torch.utils.data import DataLoader

# Per-task batch size. train() concatenates one batch from each of the four
# loaders per step, so each optimisation step sees 4 * TRAIN_BATCH_SIZE images.
TRAIN_BATCH_SIZE = 2

rm_train_dataloader, en_train_dataloader, cp_train_dataloader, sp_train_dataloader = (
    DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    for dataset in (rm_train_dataset, en_train_dataset, cp_train_dataset, sp_train_dataset)
)

In [5]:
def infinite_iter(dataloader):
    """Yield batches from ``dataloader`` forever, restarting it each time it is exhausted.

    Every pass calls ``iter(dataloader)`` afresh, so a DataLoader created with
    ``shuffle=True`` draws a new ordering on each pass.

    Args:
        dataloader: a re-iterable source of batches (e.g. a DataLoader or a list).

    Raises:
        ValueError: if a full pass produces no items — an empty dataset, or a
            one-shot iterator that has already been consumed. Without this check
            the generator would spin forever without ever yielding.
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError('infinite_iter: dataloader produced no batches; cannot cycle it')

In [6]:
# Endless batch streams, one per manipulation type; train() pulls one batch
# from each stream per optimisation step.
rm_train_iter, en_train_iter, cp_train_iter, sp_train_iter = map(
    infinite_iter,
    (rm_train_dataloader, en_train_dataloader, cp_train_dataloader, sp_train_dataloader),
)

In [7]:
from val_dataset import VAL_Dataset

# Pre-generated validation sets, one per manipulation type.
rm_val = VAL_Dataset('/home/jayda960825/Documents/Dresden/dresden/rm_val')

en_val = VAL_Dataset('/home/jayda960825/Documents/Dresden/dresden/en_val')

sp_val = VAL_Dataset('/home/jayda960825/Documents/Sliced_val_coco')

cp_val = VAL_Dataset('/home/jayda960825/Documents/Copymove_val_coco')

# The validation loop in train() reads index i from all four sets, so it can
# only visit as many items as the shortest set holds. Sizing from rm_val alone
# would index past the end of any shorter set partway through validation.
num_imgs = min(len(ds) for ds in (rm_val, en_val, cp_val, sp_val))
# Fail here rather than with a ZeroDivisionError after a full training epoch.
assert num_imgs > 0, 'at least one validation set is empty'

In [8]:
# import matplotlib.pyplot as plt

# img, mask = next(sp_train_iter)
# print(img.shape, mask.shape)

# plt.imshow(img[1].numpy().transpose(1, 2, 0))
# plt.show()
# plt.imshow(mask[1].numpy().transpose(1, 2, 0).squeeze(), cmap='gray')
# plt.show()

In [9]:
from tqdm import tqdm
# Fixed order in which the four manipulation types are concatenated into each
# batch. Training and validation use the same order so batch slots line up.
_TASK_ORDER = ('rm', 'en', 'cp', 'sp')


def _train_epoch(model, optim, num_iter, iters, criterion, device):
    """Run ``num_iter`` optimisation steps, each on one batch from every task stream.

    Images and masks from the four streams (in ``_TASK_ORDER``) are concatenated
    along dim 0 and moved to ``device``. Each step prints ``<step> <loss>``.
    """
    import torch

    model.train()
    for i in range(num_iter):
        imgs, masks = zip(*(next(iters[key]) for key in _TASK_ORDER))
        img = torch.cat(imgs, dim=0).to(device)
        gt_masking = torch.cat(masks, dim=0).to(device)

        pred_masking = model(img)
        loss = criterion(pred_masking, gt_masking)

        optim.zero_grad()
        loss.backward()
        optim.step()
        print(i, loss.item())


def _validate(model, criterion, val_sets, n_items, device):
    """Return the mean validation loss over the first ``n_items`` indices.

    For each index, one sample is taken from every dataset in ``val_sets``. The
    samples are stacked into one batch and scored with ``criterion``.
    Gradient tracking is disabled for the whole pass.

    Raises:
        ValueError: if ``n_items`` is not positive (the mean would be undefined).
    """
    import torch

    if n_items <= 0:
        raise ValueError('n_items must be positive, got {}'.format(n_items))
    model.eval()
    total = 0.0
    with torch.no_grad():
        for i in tqdm(range(n_items)):
            imgs, masks = zip(*(ds[i] for ds in val_sets))
            img = torch.stack(imgs).to(device)
            gt_masking = torch.stack(masks).to(device)
            total += criterion(model(img), gt_masking).item()
    return total / n_items


def train(model, optim, num_iter, iters, criterion, epochs=30, valid_loss_min=float('inf'),
          save_path='mantra.pth', device='cuda'):
    """Alternate training and validation for ``epochs`` rounds, checkpointing the best model.

    Args:
        model: network mapping an image batch to a predicted mask batch.
        optim: optimiser over ``model``'s parameters.
        num_iter: optimisation steps per epoch.
        iters: dict with keys 'rm', 'en', 'cp', 'sp', each an endless iterator
            of ``(images, masks)`` batches.
        criterion: loss comparing predicted and ground-truth masks.
        epochs: number of train/validate rounds.
        valid_loss_min: best validation loss so far. A checkpoint is written only
            when an epoch's loss is <= this value; pass a previous return value
            to resume without overwriting a better checkpoint.
        save_path: file the best ``model.state_dict()`` is written to.
        device: device images and masks are moved to (must match the model's).

    Returns:
        The lowest validation loss reached, i.e. the updated ``valid_loss_min``.
    """
    import torch

    # Validation data and its length come from the notebook's validation cell.
    val_sets = (rm_val, en_val, cp_val, sp_val)
    for _ in range(epochs):
        _train_epoch(model, optim, num_iter, iters, criterion, device)
        valid_loss = _validate(model, criterion, val_sets, num_imgs, device)
        if valid_loss <= valid_loss_min:
            print('validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min, valid_loss))
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss
    return valid_loss_min

In [10]:
import torch
from torch import nn

# Pretrained ManTraNet weights file loaded into the freshly built model.
PRETRAINED_WEIGHTS = '/home/jayda960825/ManTraNet_2020/pretrained_weights/ManTraNet_Ptrain4.h5'
LEARNING_RATE = 1e-04
STEPS_PER_EPOCH = 10

# NOTE(review): the meaning of create_model's positional args (4, True) lives in
# model_ASPP; the captured log line "freeze feature extraction part" suggests one
# of them freezes the feature extractor — confirm.
mantranet = model_load_weights(PRETRAINED_WEIGHTS, create_model(4, True))
optim = torch.optim.Adam(mantranet.parameters(), lr=LEARNING_RATE)
# NOTE(review): BCELoss expects predictions in [0, 1], so this assumes the model
# output is already sigmoid-activated — confirm.
criterion = nn.BCELoss()
iters = dict(rm=rm_train_iter, en=en_train_iter, cp=cp_train_iter, sp=sp_train_iter)
train(mantranet, optim, STEPS_PER_EPOCH, iters, criterion)

INFO: freeze feature extraction part, trainable=False
0 0.82274329662323
1 0.8494170904159546
2 0.6435744166374207
3 0.5568759441375732
4 0.6096160411834717
5 0.45369720458984375
6 0.4626481533050537
7 0.3677751123905182
8 0.3390902876853943


  0%|          | 0/1691 [00:00<?, ?it/s]

9 0.3509153425693512


100%|██████████| 1691/1691 [07:09<00:00,  3.94it/s]


validation loss decreased (inf --> 0.492937).  Saving model ...
0 0.3464866578578949
1 0.3973553776741028
2 0.2885802388191223
3 0.3403049111366272
4 0.27498775720596313
5 0.2650032341480255
6 0.2416774481534958
7 0.29172876477241516
8 0.33813318610191345


  0%|          | 0/1691 [00:00<?, ?it/s]

9 0.22235846519470215


100%|██████████| 1691/1691 [07:21<00:00,  3.83it/s]


validation loss decreased (0.492937 --> 0.334739).  Saving model ...
0 0.26020199060440063
1 0.2650594115257263
2 0.28586041927337646
3 0.37279418110847473
4 0.4683917164802551
5 0.2579246163368225
6 0.25195205211639404
7 0.2217407077550888
8 0.24154554307460785


  0%|          | 0/1691 [00:00<?, ?it/s]

9 0.37168389558792114


 16%|█▌        | 271/1691 [01:09<06:05,  3.89it/s]


KeyboardInterrupt: 