In [1]:
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from modeling.train_utils import array_to_dataloader
import torch
from torch.nn import functional as F

In [2]:
class reconstruct_CNN(nn.Module):
    def __init__(self, num_neuron):
        super().__init__()
        modules = []

        hidden_dims = [16, 64, 128, 64, 16]

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding= 1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
        )
        self.final_layer = nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[-1],
                                       hidden_dims[-1],
                                       kernel_size=5,
                                       stride=1,
                                       padding=2,
                                       output_padding=2,
                                       dilation=5),
                    nn.BatchNorm2d(hidden_dims[-1]),
                    nn.LeakyReLU(),
                    nn.Conv2d(hidden_dims[-1], out_channels= 1,
                              kernel_size= 3, padding= 1),
                    nn.Tanh())


        self.layers = nn.Sequential(*modules)
        self.linear_input = nn.Linear(num_neuron, hidden_dims[0] * 4)

    def forward(self, x):
        x = self.linear_input(x)
        x = x.view(-1, 16, 2, 2)
        x = self.layers(x)
        x = self.final_layer(x)
        return x

In [3]:
site = 'm1s1'
train_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/train_img_'+site+'.npy')
val_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/val_img_'+site+'.npy')
train_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/trainRsp_'+site+'.npy')
val_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/valRsp_'+site+'.npy')
model = reconstruct_CNN(302)
train_x = np.transpose(train_x, (0, 3, 1, 2))
val_x = np.transpose(val_x, (0, 3, 1, 2))
train_loader = array_to_dataloader(train_x, train_y, batch_size=1024, shuffle=True)
val_loader = array_to_dataloader(val_x, val_y, batch_size=1024)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
criterion = F.mse_loss
device = 'cuda'

In [4]:
network = model.to(device)
losses = []
accs = []

bestloss = 200
num_epochs = 100
for e in tqdm(range(num_epochs)):

    train_losses = []
    network = network.train()
    for i, (x, y) in enumerate(train_loader):
        x = x.float().to(device)
        y = y.float().to(device)
        preds = network(y)
        loss = criterion(preds, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    val_losses = []
    with torch.no_grad():
        network = network.eval()
        for i, (x, y) in enumerate(val_loader):
            x = x.float().to(device)
            y = y.float().to(device)
            preds = network(y)
            loss = criterion(preds, x)
            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    if avg_loss < bestloss:
        torch.save(network.state_dict(), "direct_CNN")
        bestloss = avg_loss

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

  1%|          | 1/100 [00:08<14:19,  8.68s/it]

epoch 0 : train loss is 0.04768678332295488
epoch 0 : val loss is   0.02894783206284046


  2%|▏         | 2/100 [00:14<11:51,  7.26s/it]

epoch 1 : train loss is 0.027093951919061297
epoch 1 : val loss is   0.028357338160276413


  3%|▎         | 3/100 [00:21<10:56,  6.77s/it]

epoch 2 : train loss is 0.024613526266287353
epoch 2 : val loss is   0.02641623094677925


  4%|▍         | 4/100 [00:27<10:30,  6.57s/it]

epoch 3 : train loss is 0.02356200955589028
epoch 3 : val loss is   0.02338859811425209


  5%|▌         | 5/100 [00:33<10:11,  6.43s/it]

epoch 4 : train loss is 0.022826017757110736
epoch 4 : val loss is   0.02241644449532032


  6%|▌         | 6/100 [00:39<09:57,  6.36s/it]

epoch 5 : train loss is 0.02239964316215585
epoch 5 : val loss is   0.022203637287020683


  7%|▋         | 7/100 [00:46<09:47,  6.31s/it]

epoch 6 : train loss is 0.02218365137848784
epoch 6 : val loss is   0.02077128179371357


  8%|▊         | 8/100 [00:52<09:41,  6.33s/it]

epoch 7 : train loss is 0.022076464148567003
epoch 7 : val loss is   0.02127241902053356


  9%|▉         | 9/100 [00:58<09:34,  6.32s/it]

epoch 8 : train loss is 0.021969860419631004
epoch 8 : val loss is   0.020950915291905403


 10%|█         | 10/100 [01:05<09:29,  6.32s/it]

epoch 9 : train loss is 0.021910426097319406
epoch 9 : val loss is   0.021001452580094337


 11%|█         | 11/100 [01:11<09:20,  6.30s/it]

epoch 10 : train loss is 0.02187724071828758
epoch 10 : val loss is   0.021061303094029427


 12%|█▏        | 12/100 [01:17<09:13,  6.29s/it]

epoch 11 : train loss is 0.02174416902091573
epoch 11 : val loss is   0.02043836936354637


 13%|█▎        | 13/100 [01:23<09:02,  6.24s/it]

epoch 12 : train loss is 0.021818591172204298
epoch 12 : val loss is   0.02043747529387474


 14%|█▍        | 14/100 [01:29<08:49,  6.16s/it]

epoch 13 : train loss is 0.021743569632663447
epoch 13 : val loss is   0.01990531198680401


 15%|█▌        | 15/100 [01:35<08:38,  6.10s/it]

epoch 14 : train loss is 0.0216620883301777
epoch 14 : val loss is   0.020153522491455078


 16%|█▌        | 16/100 [01:41<08:28,  6.06s/it]

epoch 15 : train loss is 0.021642353550037918
epoch 15 : val loss is   0.021199140697717667


 17%|█▋        | 17/100 [01:47<08:21,  6.04s/it]

epoch 16 : train loss is 0.0217143756830517
epoch 16 : val loss is   0.020386653020977974


 18%|█▊        | 18/100 [01:53<08:16,  6.06s/it]

epoch 17 : train loss is 0.02156508352388354
epoch 17 : val loss is   0.020682474598288536


 19%|█▉        | 19/100 [01:59<08:16,  6.13s/it]

epoch 18 : train loss is 0.021542893503518665
epoch 18 : val loss is   0.01977781020104885


 20%|██        | 20/100 [02:06<08:14,  6.19s/it]

epoch 19 : train loss is 0.021547339144436753
epoch 19 : val loss is   0.01996651664376259


 21%|██        | 21/100 [02:12<08:13,  6.25s/it]

epoch 20 : train loss is 0.021590327405754256
epoch 20 : val loss is   0.019856097176671028


 22%|██▏       | 22/100 [02:19<08:12,  6.32s/it]

epoch 21 : train loss is 0.02152296309085453
epoch 21 : val loss is   0.019743019714951515


 23%|██▎       | 23/100 [02:25<08:03,  6.28s/it]

epoch 22 : train loss is 0.02140875900273814
epoch 22 : val loss is   0.020257139578461647


 24%|██▍       | 24/100 [02:31<07:57,  6.28s/it]

epoch 23 : train loss is 0.02143524882986265
epoch 23 : val loss is   0.02023649960756302


 25%|██▌       | 25/100 [02:37<07:48,  6.25s/it]

epoch 24 : train loss is 0.021489230970687726
epoch 24 : val loss is   0.020027954131364822


 26%|██▌       | 26/100 [02:44<07:44,  6.27s/it]

epoch 25 : train loss is 0.02144248947939452
epoch 25 : val loss is   0.020053289830684662


 27%|██▋       | 27/100 [02:50<07:37,  6.27s/it]

epoch 26 : train loss is 0.021400870810098508
epoch 26 : val loss is   0.02023790031671524


 28%|██▊       | 28/100 [02:56<07:32,  6.29s/it]

epoch 27 : train loss is 0.021417353466591415
epoch 27 : val loss is   0.019972220063209534


 29%|██▉       | 29/100 [03:03<07:27,  6.31s/it]

epoch 28 : train loss is 0.021367466625045326
epoch 28 : val loss is   0.02010148949921131


 30%|███       | 30/100 [03:09<07:21,  6.31s/it]

epoch 29 : train loss is 0.021447934155516764
epoch 29 : val loss is   0.01979028433561325


 31%|███       | 31/100 [03:15<07:11,  6.25s/it]

epoch 30 : train loss is 0.021352099912131533
epoch 30 : val loss is   0.02005261741578579


 32%|███▏      | 32/100 [03:21<07:04,  6.25s/it]

epoch 31 : train loss is 0.021334341069793
epoch 31 : val loss is   0.01976514235138893


 33%|███▎      | 33/100 [03:27<06:54,  6.19s/it]

epoch 32 : train loss is 0.021243941531900096
epoch 32 : val loss is   0.020233705639839172


 34%|███▍      | 34/100 [03:33<06:45,  6.15s/it]

epoch 33 : train loss is 0.02130088023841381
epoch 33 : val loss is   0.02014862187206745


 35%|███▌      | 35/100 [03:39<06:37,  6.11s/it]

epoch 34 : train loss is 0.021345300326014265
epoch 34 : val loss is   0.020321663469076157


 36%|███▌      | 36/100 [03:45<06:27,  6.05s/it]

epoch 35 : train loss is 0.02137646051671575
epoch 35 : val loss is   0.020004134625196457


 37%|███▋      | 37/100 [03:51<06:20,  6.04s/it]

epoch 36 : train loss is 0.021247119094957325
epoch 36 : val loss is   0.020070545375347137


 38%|███▊      | 38/100 [03:58<06:18,  6.11s/it]

epoch 37 : train loss is 0.02133065741509199
epoch 37 : val loss is   0.020106354728341103


 39%|███▉      | 39/100 [04:04<06:17,  6.19s/it]

epoch 38 : train loss is 0.02125159021028701
epoch 38 : val loss is   0.01996791362762451


 40%|████      | 40/100 [04:10<06:14,  6.25s/it]

epoch 39 : train loss is 0.021194428479408518
epoch 39 : val loss is   0.02006358653306961


 41%|████      | 41/100 [04:16<06:06,  6.22s/it]

epoch 40 : train loss is 0.021248600226553047
epoch 40 : val loss is   0.019894519820809364


 42%|████▏     | 42/100 [04:23<06:00,  6.22s/it]

epoch 41 : train loss is 0.021269225405857843
epoch 41 : val loss is   0.02016736939549446


 43%|████▎     | 43/100 [04:29<05:54,  6.22s/it]

epoch 42 : train loss is 0.021182558499276638
epoch 42 : val loss is   0.02016105130314827


 44%|████▍     | 44/100 [04:35<05:47,  6.21s/it]

epoch 43 : train loss is 0.021197022596264586
epoch 43 : val loss is   0.020040377974510193


 45%|████▌     | 45/100 [04:41<05:41,  6.20s/it]

epoch 44 : train loss is 0.021144309957676074
epoch 44 : val loss is   0.020323365926742554


 46%|████▌     | 46/100 [04:47<05:34,  6.20s/it]

epoch 45 : train loss is 0.021288519013015664
epoch 45 : val loss is   0.019763750955462456


 47%|████▋     | 47/100 [04:54<05:29,  6.21s/it]

epoch 46 : train loss is 0.021153125966734746
epoch 46 : val loss is   0.02016441524028778


 48%|████▊     | 48/100 [05:00<05:23,  6.23s/it]

epoch 47 : train loss is 0.0211724180399495
epoch 47 : val loss is   0.02056507207453251


 49%|████▉     | 49/100 [05:06<05:17,  6.23s/it]

epoch 48 : train loss is 0.021132133627200827
epoch 48 : val loss is   0.02025441825389862


 50%|█████     | 50/100 [05:13<05:15,  6.31s/it]

epoch 49 : train loss is 0.021137956913341496
epoch 49 : val loss is   0.020810097455978394


 51%|█████     | 51/100 [05:19<05:10,  6.33s/it]

epoch 50 : train loss is 0.02116519393508925
epoch 50 : val loss is   0.020149491727352142


 52%|█████▏    | 52/100 [05:25<04:58,  6.22s/it]

epoch 51 : train loss is 0.021126560428563285
epoch 51 : val loss is   0.02027815580368042


 53%|█████▎    | 53/100 [05:31<04:47,  6.13s/it]

epoch 52 : train loss is 0.02111853429061525
epoch 52 : val loss is   0.01993698440492153


 54%|█████▍    | 54/100 [05:37<04:38,  6.06s/it]

epoch 53 : train loss is 0.021152833421878955
epoch 53 : val loss is   0.020637452602386475


 55%|█████▌    | 55/100 [05:43<04:30,  6.01s/it]

epoch 54 : train loss is 0.02109675295650959
epoch 54 : val loss is   0.02024885267019272


 56%|█████▌    | 56/100 [05:49<04:23,  5.98s/it]

epoch 55 : train loss is 0.021093760945779437
epoch 55 : val loss is   0.02031399868428707


 57%|█████▋    | 57/100 [05:55<04:21,  6.07s/it]

epoch 56 : train loss is 0.021078978281687286
epoch 56 : val loss is   0.019918683916330338


 58%|█████▊    | 58/100 [06:01<04:19,  6.18s/it]

epoch 57 : train loss is 0.02107065516140531
epoch 57 : val loss is   0.020009556785225868


 59%|█████▉    | 59/100 [06:08<04:13,  6.19s/it]

epoch 58 : train loss is 0.021089438568143284
epoch 58 : val loss is   0.020139280706644058


 60%|██████    | 60/100 [06:14<04:08,  6.22s/it]

epoch 59 : train loss is 0.021070946829722208
epoch 59 : val loss is   0.02075464092195034


 61%|██████    | 61/100 [06:20<04:02,  6.22s/it]

epoch 60 : train loss is 0.02103784422883216
epoch 60 : val loss is   0.020824236795306206


 62%|██████▏   | 62/100 [06:26<03:56,  6.22s/it]

epoch 61 : train loss is 0.021019016940365818
epoch 61 : val loss is   0.02019069716334343


 63%|██████▎   | 63/100 [06:33<03:50,  6.23s/it]

epoch 62 : train loss is 0.021048744866514906
epoch 62 : val loss is   0.020654944702982903


 64%|██████▍   | 64/100 [06:39<03:45,  6.26s/it]

epoch 63 : train loss is 0.02096313386059859
epoch 63 : val loss is   0.020358404144644737


 65%|██████▌   | 65/100 [06:45<03:40,  6.29s/it]

epoch 64 : train loss is 0.021040090578882134
epoch 64 : val loss is   0.019734041765332222


 66%|██████▌   | 66/100 [06:52<03:32,  6.25s/it]

epoch 65 : train loss is 0.021052421728039488
epoch 65 : val loss is   0.02004658244550228





KeyboardInterrupt: 

In [5]:
import matplotlib.pyplot as plt

In [6]:
sample = torch.tensor(val_y[:], dtype=torch.float).to(device)
#model.load_state_dict(torch.load('direct_CNN'))
model = model.to(device)
recon = model(sample).detach().cpu().numpy()
origin = val_x[:100]
for i, (r_img, img) in enumerate(zip(recon, origin)):
    r_img = np.reshape(r_img, (50, 50))
    img = np.reshape(img, (50, 50))
    print("newimg")
    plt.imsave(f'results/recon_{i}.png',r_img, cmap='gray')
    plt.show()
    plt.imsave(f'results/origin_{i}.png',img, cmap='gray')
    plt.show()

newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg


In [7]:
val_losses = []
accs = []
with torch.no_grad():
    network = model.eval()
    for i, (x, y) in enumerate(val_loader):
        x = x.float().to(device)
        y = y.float().to(device)
        preds = network(y)
        loss = criterion(preds, x)
        val_losses.append(loss.item())
avg_loss = np.mean(val_losses)
accs.append(avg_loss)

print(f'val loss is   {float(accs[-1])}')

val loss is   0.02003301866352558
