In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.nn.functional import conv2d
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from modeling.models.bethge import BethgeModel
from modeling.train_utils import array_to_dataloader
# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# TODO: fine-tune with the validation set?

CIFAR reconstruction tests using the shared core

In [2]:

# CIFAR-10 preprocessing:
# image -> tensor -> resize (shorter side to 50 px) -> per-channel normalisation to [-1, 1]
# -> single grayscale channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(50),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Grayscale(),
])

batch_size = 100

# Build the train split first, then the test split (each prints the download status line).
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train,
                                 download=True, transform=transform)
    for is_train in (True, False)
)

# The test loader keeps a fixed order; the train loader reshuffles every epoch.
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Shared-core response model: 1-channel 50 x 50 input -> 324 outputs.
channels = 256
num_layers = 9
input_size = 50

output_size = 324
first_k = 9
later_k = 3
pool_size = 2
factorized = True

num_maps = 1

# Pass the configured `factorized` flag (it was previously hard-coded to True) and move the
# model to `device` once instead of calling .cuda() and then .to(device).
net = BethgeModel(channels=channels, num_layers=num_layers, input_size=input_size,
                  output_size=output_size, first_k=first_k, later_k=later_k,
                  input_channels=1, pool_size=pool_size, factorized=factorized,
                  num_maps=num_maps).to(device)

# Optional: restore pretrained weights instead of using a freshly initialised core.
#net.load_state_dict(torch.load('../saved_models/new_learned_models/m2s1_9_model_version_0'))
#net.load_state_dict(torch.load(f'../saved_models/cropped_models/m2s1_size_{input_size}_model'))
net

BethgeModel(
  (norm): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1), bias=False)
      (1): Softplus(beta=1, threshold=20)
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): Softplus(beta=1, threshold=20)
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): Softplus(beta=1, threshold=20)
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): Softplus(beta=1

In [4]:
class reconstruct_CNN(nn.Module):
    """Decode a vector of neuron responses into a single-channel image.

    Architecture:
    - A linear layer maps the ``num_neuron`` responses onto a
      ``(hidden_dims[0], 2, 2)`` seed tensor.
    - One stride-2 transposed convolution per consecutive pair in
      ``hidden_dims`` upsamples it; each layer doubles the spatial size.
    - A dilated transposed convolution and a 3x3 convolution produce the image.

    With the default ``hidden_dims`` the output has shape ``(batch, 1, 50, 50)``
    with values in ``[-1, 1]`` (final ``Tanh``).

    Args:
        num_neuron: length of the response vector passed to ``forward``.
        hidden_dims: channel widths of the upsampling stack. Its length sets the
            number of stride-2 upsampling layers and therefore the output
            resolution.
    """

    # Spatial side length of the seed tensor produced by the input linear layer.
    seed_size = 2

    def __init__(self, num_neuron, hidden_dims=(16, 64, 128, 64, 16)):
        super().__init__()
        hidden_dims = list(hidden_dims)
        # Plain attribute (not a module); forward() uses it to shape the seed tensor.
        self.hidden_dims = hidden_dims

        modules = []
        for in_ch, out_ch in zip(hidden_dims[:-1], hidden_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                       padding=1, output_padding=1),
                    nn.BatchNorm2d(out_ch),
                    nn.LeakyReLU()))

        # Registration order (final_layer, layers, linear_input) and module indices match the
        # original, so previously saved state dicts still load.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=5,
                               stride=1,
                               padding=2,
                               output_padding=2,
                               dilation=5),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=1,
                      kernel_size=3, padding=1),
            nn.Tanh())

        self.layers = nn.Sequential(*modules)
        self.linear_input = nn.Linear(num_neuron, hidden_dims[0] * self.seed_size ** 2)

    def forward(self, x):
        """Map responses of shape ``(batch, num_neuron)`` to images of shape ``(batch, 1, H, W)``."""
        x = self.linear_input(x)
        # Seed shape comes from hidden_dims rather than a hard-coded (16, 2, 2).
        x = x.view(-1, self.hidden_dims[0], self.seed_size, self.seed_size)
        x = self.layers(x)
        return self.final_layer(x)

In [7]:

# Response model in eval mode; fresh decoder (324 responses -> image) and its optimiser.
network = net.to(device).eval()
prednet = reconstruct_CNN(324).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)
criterion = nn.functional.mse_loss

# Per-epoch histories: training MSE in `losses`, validation MSE in `accs` (not an accuracy).
losses, accs = [], []


In [29]:
# Train the decoder to reconstruct CIFAR images from the shared-core responses.
# Only `prednet` is optimised; `network` stays fixed in eval mode.
bestloss = float('inf')  # best validation MSE so far; a checkpoint is saved whenever it improves
num_epochs = 100
crop = transforms.CenterCrop(input_size)  # response-model input crop (a no-op while input_size == 50)
for e in tqdm(range(num_epochs)):
    train_losses = []
    prednet = prednet.train()
    for x, _ in trainloader:
        x = x.float().to(device)
        # The response model is not being trained: skip its autograd graph so backward()
        # does not also compute (and accumulate) gradients for its parameters.
        with torch.no_grad():
            rsp = network(crop(x))
        recon = prednet(rsp)
        loss = criterion(recon, x)  # the target is the full (uncropped) image

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    val_losses = []
    prednet = prednet.eval()
    with torch.no_grad():
        for x, _ in testloader:
            x = x.float().to(device)
            recon = prednet(network(crop(x)))
            loss = criterion(recon, x)
            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "sanity_check_m2s1_model")
        bestloss = avg_loss

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

  1%|          | 1/100 [01:55<3:09:51, 115.06s/it]

epoch 0 : train loss is 0.14625312180817127
epoch 0 : val loss is   0.11408415392041206


  2%|▏         | 2/100 [03:50<3:07:59, 115.09s/it]

epoch 1 : train loss is 0.1092441799044609
epoch 1 : val loss is   0.1111518020182848


  3%|▎         | 3/100 [05:45<3:06:16, 115.22s/it]

epoch 2 : train loss is 0.1014069500118494
epoch 2 : val loss is   0.10031491592526436


  4%|▍         | 4/100 [07:38<3:02:39, 114.16s/it]

epoch 3 : train loss is 0.09681678412854672
epoch 3 : val loss is   0.09602990813553333


  5%|▌         | 5/100 [09:30<2:59:29, 113.37s/it]

epoch 4 : train loss is 0.09399545885622501
epoch 4 : val loss is   0.09459649935364724


  6%|▌         | 6/100 [11:22<2:56:55, 112.93s/it]

epoch 5 : train loss is 0.09179045873880386
epoch 5 : val loss is   0.0910287506878376


  7%|▋         | 7/100 [13:14<2:54:42, 112.71s/it]

epoch 6 : train loss is 0.0901765878200531
epoch 6 : val loss is   0.09155401557683945


  8%|▊         | 8/100 [15:06<2:52:42, 112.64s/it]

epoch 7 : train loss is 0.08864386230707169
epoch 7 : val loss is   0.08789640225470066


  8%|▊         | 8/100 [16:55<3:14:42, 126.98s/it]


KeyboardInterrupt: 

In [85]:

# Mean validation MSE of the decoder on the CIFAR test set.
# Response-model inputs are centre-cropped exactly as in the training loop
# (a no-op while input_size == 50). The reconstruction target is the full image.
eval_crop = transforms.CenterCrop(input_size)
prednet = prednet.eval()
with torch.no_grad():
    val_losses = []
    for x, _ in testloader:
        x = x.float().to(device)
        recon = prednet(network(eval_crop(x)))
        val_losses.append(criterion(recon, x).item())
    avg_loss = np.mean(val_losses)
    print(avg_loss)

0.07913384726271033


In [17]:
# Save reconstructions and originals for the first 100 CIFAR test images.
# NOTE(review): the checkpoint names refer to 35-px crops, but `network` was built with
# input_size = 50 and receives the full 50 x 50 images below. Confirm the size-35 weights
# fit this model and whether the inputs should be centre-cropped.
out_dir = 'recon_artificial_m2s1_35'
os.makedirs(out_dir, exist_ok=True)  # imsave fails if the directory does not exist
prednet.load_state_dict(torch.load('artificial_recon_model_cropped_m2s1_35'))
network.load_state_dict(torch.load('../saved_models/cropped_models/m2s1_size_35_model'))
network.eval()
prednet.eval()
with torch.no_grad():
    for i in range(100):
        sample, _ = testset[i]
        sample = torch.reshape(sample.to(device), (1, 1, 50, 50))
        recon = prednet(network(sample)).cpu().numpy()
        origin = sample.cpu().numpy()

        plt.imsave(os.path.join(out_dir, f'recon_{i}.png'), np.reshape(recon, (50, 50)), cmap='gray')
        plt.imsave(os.path.join(out_dir, f'origin_{i}.png'), np.reshape(origin, (50, 50)), cmap='gray')

Tang data reconstruction with shared core

In [20]:
# Load Tang images and recorded responses for one site and wrap them in DataLoaders.
site = 'm3s1'
train_x, val_x = (
    np.load(f'../data/Processed_Tang_data/all_sites_data_prepared/pics_data/{split}_img_{site}.npy')
    for split in ('train', 'val')
)
train_y, val_y = (
    np.load(f'../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/{split}Rsp_{site}.npy')
    for split in ('train', 'val')
)
# Move the last image axis (channels) to position 1, the layout PyTorch convolutions expect.
train_x, val_x = (np.transpose(arr, (0, 3, 1, 2)) for arr in (train_x, val_x))
train_loader = array_to_dataloader(train_x, train_y, batch_size=20, shuffle=True)
val_loader = array_to_dataloader(val_x, val_y, batch_size=20)

In [None]:
# Train a decoder that reconstructs Tang m3s1 images directly from recorded responses.
# The decoder's input width comes from the response array instead of a hard-coded 324.
prednet = reconstruct_CNN(train_y.shape[-1]).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)
criterion = nn.MSELoss()
losses = []
accs = []  # validation MSE per epoch (not an accuracy)
bestloss = float('inf')
num_epochs = 100
for e in tqdm(range(num_epochs)):
    train_losses = []
    prednet = prednet.train()
    for x, y in train_loader:
        x = x.float().to(device)
        y = y.float().to(device)
        recon = prednet(y)
        loss = criterion(recon, x)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    val_losses = []
    prednet = prednet.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x = x.float().to(device)
            y = y.float().to(device)
            recon = prednet(y)
            loss = criterion(recon, x)
            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "recon_model_tang_data_m3s1")  # best-so-far checkpoint
        bestloss = avg_loss
    torch.save(prednet.state_dict(), "recon_model_tang_data_m3s1_acc")  # latest-epoch snapshot

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

In [7]:
# Decode the first 10 m3s1 validation images and save reconstructions next to the originals.
site = 'm3s1'
train_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/train_img_'+site+'.npy')
val_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/val_img_'+site+'.npy')
train_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/trainRsp_'+site+'.npy')
val_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/valRsp_'+site+'.npy')
out_dir = 'recon_tang_m3s1'
os.makedirs(out_dir, exist_ok=True)  # imsave fails if the directory does not exist

prednet.load_state_dict(torch.load('recon_model_tang_data_m3s1'))
# load_state_dict does not change the train/eval flag; BatchNorm must use its running stats here.
prednet = prednet.to(device).eval()

sample = torch.tensor(val_x[:10], dtype=torch.float).to(device)
sample = torch.reshape(sample, (sample.shape[0], 1, 50, 50))
# NOTE(review): the decoder was trained on recorded responses (`prednet(y)`), but here it is fed
# responses predicted by `network` from the images. To check the decoder itself, decode
# val_y[:10] instead.
with torch.no_grad():
    recon = prednet(network(sample)).cpu().numpy()
origin = val_x[:10]
for i, (r_img, img) in enumerate(zip(recon, origin)):
    plt.imsave(os.path.join(out_dir, f'recon_{i}.png'), np.reshape(r_img, (50, 50)), cmap='gray')
    plt.imsave(os.path.join(out_dir, f'origin_{i}.png'), np.reshape(img, (50, 50)), cmap='gray')

newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg


In [27]:
class Mapper(torch.nn.Module):
    """MLP mapping a response vector onto another vector of the same length.

    Widths: numNeurons -> 512 -> 1024 -> 512 -> numNeurons. A LeakyReLU sits
    between consecutive Linear layers; the output layer has no activation.
    """

    def __init__(self, numNeurons):
        super().__init__()
        widths = (numNeurons, 512, 1024, 512, numNeurons)
        stack = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth > 0:
                stack.append(nn.LeakyReLU())
            stack.append(nn.Linear(fan_in, fan_out))
        # A single Sequential keeps the state-dict keys at layers.{0,2,4,6}.* so saved
        # checkpoints still load.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)
# Response mapper (299 -> 299), restored from the saved 'real_fake_mapper_corr' checkpoint.
network = Mapper(299)
network = network.to(device)
network.load_state_dict(torch.load('real_fake_mapper_corr'))

<All keys matched successfully>

In [30]:
# Load Tang m2s1 images and recorded responses; build a decoder for the mapper's 299 outputs.
site = 'm2s1'
train_x, val_x = (
    np.load(f'../data/Processed_Tang_data/all_sites_data_prepared/pics_data/{split}_img_{site}.npy')
    for split in ('train', 'val')
)
train_y, val_y = (
    np.load(f'../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/{split}Rsp_{site}.npy')
    for split in ('train', 'val')
)
# Move the last image axis (channels) to position 1, the layout PyTorch convolutions expect.
train_x, val_x = (np.transpose(arr, (0, 3, 1, 2)) for arr in (train_x, val_x))
train_loader = array_to_dataloader(train_x, train_y, batch_size=20, shuffle=True)
val_loader = array_to_dataloader(val_x, val_y, batch_size=20)

prednet = reconstruct_CNN(299).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.0005)
criterion = nn.MSELoss()
losses, accs = [], []
bestloss = 200
num_epochs = 100


In [31]:
# Train the decoder on Tang m2s1: recorded responses y -> mapper (`network`) -> decoder -> image.
# Only `prednet` is optimised; the mapper stays fixed.
for e in tqdm(range(num_epochs)):
    train_losses = []
    prednet = prednet.train()
    for x, y in train_loader:
        x = x.float().to(device)
        y = y.float().to(device)
        # Without no_grad, every backward() would also accumulate gradients into the mapper's
        # parameters, which no optimiser ever uses or clears.
        with torch.no_grad():
            rsp = network(y)
        recon = prednet(rsp)
        loss = criterion(recon, x)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    val_losses = []
    prednet = prednet.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x = x.float().to(device)
            y = y.float().to(device)
            recon = prednet(network(y))
            loss = criterion(recon, x)
            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "recon_model_tang_fakeReal_m2s1")  # best-so-far checkpoint
        bestloss = avg_loss
    torch.save(prednet.state_dict(), "recon_model_tang_fakeReal_m2s1_acc")  # latest-epoch snapshot

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')



  1%|          | 1/100 [00:19<31:24, 19.03s/it]

epoch 0 : train loss is 0.016884978087809012
epoch 0 : val loss is   0.197516151368618


  2%|▏         | 2/100 [00:37<30:55, 18.93s/it]

epoch 1 : train loss is 0.015148887807997514
epoch 1 : val loss is   0.17448889032006265


  3%|▎         | 3/100 [00:58<31:59, 19.79s/it]

epoch 2 : train loss is 0.014731904284610431
epoch 2 : val loss is   0.15641219809651374


  4%|▍         | 4/100 [01:17<31:06, 19.45s/it]

epoch 3 : train loss is 0.014413499657292755
epoch 3 : val loss is   0.159585816860199


  5%|▌         | 5/100 [01:36<30:30, 19.26s/it]

epoch 4 : train loss is 0.01415679029284083
epoch 4 : val loss is   0.16214995682239533


  6%|▌         | 6/100 [01:58<31:39, 20.21s/it]

epoch 5 : train loss is 0.01403614969163829
epoch 5 : val loss is   0.14965099781751634


  7%|▋         | 7/100 [02:18<31:07, 20.08s/it]

epoch 6 : train loss is 0.01392604076512614
epoch 6 : val loss is   0.1516859546303749


  8%|▊         | 8/100 [02:36<29:57, 19.54s/it]

epoch 7 : train loss is 0.013826183146055864
epoch 7 : val loss is   0.1531093680858612


  9%|▉         | 9/100 [02:55<29:10, 19.24s/it]

epoch 8 : train loss is 0.013735271670228364
epoch 8 : val loss is   0.14772795543074607


 10%|█         | 10/100 [03:13<28:23, 18.93s/it]

epoch 9 : train loss is 0.013686756702634145
epoch 9 : val loss is   0.15389274582266807


 11%|█         | 11/100 [03:31<27:40, 18.66s/it]

epoch 10 : train loss is 0.013615378696395427
epoch 10 : val loss is   0.14789145022630693


 12%|█▏        | 12/100 [03:49<27:03, 18.45s/it]

epoch 11 : train loss is 0.013575011684502265
epoch 11 : val loss is   0.15205459594726561


 13%|█▎        | 13/100 [04:08<26:46, 18.47s/it]

epoch 12 : train loss is 0.013519523629196444
epoch 12 : val loss is   0.15740030154585838


 14%|█▍        | 14/100 [04:26<26:23, 18.42s/it]

epoch 13 : train loss is 0.013459194891002713
epoch 13 : val loss is   0.14648803174495698


 15%|█▌        | 15/100 [04:44<26:05, 18.41s/it]

epoch 14 : train loss is 0.013411694703509613
epoch 14 : val loss is   0.15006860703229905


 16%|█▌        | 16/100 [05:03<25:41, 18.35s/it]

epoch 15 : train loss is 0.013369065739062368
epoch 15 : val loss is   0.14722641363739966


 17%|█▋        | 17/100 [05:21<25:19, 18.30s/it]

epoch 16 : train loss is 0.013323390930510905
epoch 16 : val loss is   0.152011236846447


 18%|█▊        | 18/100 [05:39<25:02, 18.32s/it]

epoch 17 : train loss is 0.013294619105148071
epoch 17 : val loss is   0.1569168108701706


 19%|█▉        | 19/100 [05:58<24:45, 18.34s/it]

epoch 18 : train loss is 0.013240778271336944
epoch 18 : val loss is   0.1534159579873085


 20%|██        | 20/100 [06:16<24:25, 18.32s/it]

epoch 19 : train loss is 0.013214543159785016
epoch 19 : val loss is   0.15533978268504142


 21%|██        | 21/100 [06:34<24:11, 18.38s/it]

epoch 20 : train loss is 0.013183379894768706
epoch 20 : val loss is   0.14783710479736328


 22%|██▏       | 22/100 [06:53<24:10, 18.59s/it]

epoch 21 : train loss is 0.01313333203027747
epoch 21 : val loss is   0.16577915132045745


 23%|██▎       | 23/100 [07:15<24:59, 19.48s/it]

epoch 22 : train loss is 0.013115834699639556
epoch 22 : val loss is   0.14456997364759444


 24%|██▍       | 24/100 [07:35<24:52, 19.63s/it]

epoch 23 : train loss is 0.013068322925543299
epoch 23 : val loss is   0.15192643851041793


 25%|██▌       | 25/100 [07:56<25:06, 20.08s/it]

epoch 24 : train loss is 0.013018881006128327
epoch 24 : val loss is   0.1521327744424343


 26%|██▌       | 26/100 [08:15<24:13, 19.65s/it]

epoch 25 : train loss is 0.01299526167786395
epoch 25 : val loss is   0.15006964445114135


 27%|██▋       | 27/100 [08:34<23:36, 19.40s/it]

epoch 26 : train loss is 0.012961988986032654
epoch 26 : val loss is   0.1544570817053318


 28%|██▊       | 28/100 [08:52<23:05, 19.24s/it]

epoch 27 : train loss is 0.012926981664281719
epoch 27 : val loss is   0.1423554600775242


 29%|██▉       | 29/100 [09:13<23:26, 19.81s/it]

epoch 28 : train loss is 0.012899566500884842
epoch 28 : val loss is   0.16367422550916672


 30%|███       | 30/100 [09:35<23:31, 20.17s/it]

epoch 29 : train loss is 0.012865662151111328
epoch 29 : val loss is   0.16267801716923713


 31%|███       | 31/100 [09:57<23:51, 20.75s/it]

epoch 30 : train loss is 0.012832391936027882
epoch 30 : val loss is   0.16117963656783105


 32%|███▏      | 32/100 [10:21<24:39, 21.75s/it]

epoch 31 : train loss is 0.012817681660989717
epoch 31 : val loss is   0.1513191194832325


 33%|███▎      | 33/100 [10:44<24:48, 22.22s/it]

epoch 32 : train loss is 0.012765937614425713
epoch 32 : val loss is   0.16833265930414198


 34%|███▍      | 34/100 [11:05<24:06, 21.92s/it]

epoch 33 : train loss is 0.012755642895947914
epoch 33 : val loss is   0.16973122850060463


 35%|███▌      | 35/100 [11:26<23:14, 21.46s/it]

epoch 34 : train loss is 0.012713113092076109
epoch 34 : val loss is   0.15271683514118195


 36%|███▌      | 36/100 [11:46<22:30, 21.10s/it]

epoch 35 : train loss is 0.012680667340565397
epoch 35 : val loss is   0.15491742685437201


 37%|███▋      | 37/100 [12:06<21:42, 20.67s/it]

epoch 36 : train loss is 0.01264244555286607
epoch 36 : val loss is   0.16026091679930687


 38%|███▊      | 38/100 [12:27<21:27, 20.77s/it]

epoch 37 : train loss is 0.012616850264864612
epoch 37 : val loss is   0.1553737363219261


 39%|███▉      | 39/100 [12:46<20:46, 20.43s/it]

epoch 38 : train loss is 0.012580802507166351
epoch 38 : val loss is   0.15760955840349197


 40%|████      | 40/100 [13:06<20:17, 20.29s/it]

epoch 39 : train loss is 0.012549446579449031
epoch 39 : val loss is   0.16611159458756447


 41%|████      | 41/100 [13:25<19:34, 19.90s/it]

epoch 40 : train loss is 0.01252844175228811
epoch 40 : val loss is   0.16232905745506288


 42%|████▏     | 42/100 [13:44<18:53, 19.55s/it]

epoch 41 : train loss is 0.012489162767115904
epoch 41 : val loss is   0.16786456003785133


 43%|████▎     | 43/100 [14:04<18:38, 19.63s/it]

epoch 42 : train loss is 0.012464137389419639
epoch 42 : val loss is   0.16575498521327972


 44%|████▍     | 44/100 [14:24<18:26, 19.77s/it]

epoch 43 : train loss is 0.012443528752583935
epoch 43 : val loss is   0.1683898787200451


 45%|████▌     | 45/100 [14:48<19:21, 21.12s/it]

epoch 44 : train loss is 0.012401575743962003
epoch 44 : val loss is   0.15121909618377685


 46%|████▌     | 46/100 [15:08<18:33, 20.62s/it]

epoch 45 : train loss is 0.012364584753974056
epoch 45 : val loss is   0.16372281193733215


 47%|████▋     | 47/100 [15:34<19:53, 22.51s/it]

epoch 46 : train loss is 0.0123494551368818
epoch 46 : val loss is   0.16742963179945947


 48%|████▊     | 48/100 [16:00<20:16, 23.39s/it]

epoch 47 : train loss is 0.012325476827670117
epoch 47 : val loss is   0.1651649197936058


 49%|████▉     | 49/100 [16:26<20:28, 24.09s/it]

epoch 48 : train loss is 0.012288669038525954
epoch 48 : val loss is   0.17010887801647187


 50%|█████     | 50/100 [16:52<20:32, 24.64s/it]

epoch 49 : train loss is 0.012264217626759593
epoch 49 : val loss is   0.16640994638204576


 51%|█████     | 51/100 [17:17<20:23, 24.96s/it]

epoch 50 : train loss is 0.012232765871271187
epoch 50 : val loss is   0.1768432405591011


 52%|█████▏    | 52/100 [17:37<18:45, 23.45s/it]

epoch 51 : train loss is 0.012192739034641763
epoch 51 : val loss is   0.16763269320130347


 53%|█████▎    | 53/100 [18:18<22:29, 28.70s/it]

epoch 52 : train loss is 0.012186537093806023
epoch 52 : val loss is   0.17037574902176858


 54%|█████▍    | 54/100 [18:38<19:52, 25.93s/it]

epoch 53 : train loss is 0.012145718615882251
epoch 53 : val loss is   0.16600133284926413


 55%|█████▌    | 55/100 [19:14<21:46, 29.04s/it]

epoch 54 : train loss is 0.012134793150447765
epoch 54 : val loss is   0.16669574201107026


 56%|█████▌    | 56/100 [19:34<19:22, 26.43s/it]

epoch 55 : train loss is 0.01209380707646511
epoch 55 : val loss is   0.17697576135396959


 57%|█████▋    | 57/100 [20:10<20:52, 29.14s/it]

epoch 56 : train loss is 0.01207573827451133
epoch 56 : val loss is   0.17279201105237008


 57%|█████▋    | 57/100 [20:30<15:28, 21.59s/it]


KeyboardInterrupt: 

In [32]:
# Decode 20 m2s1 response vectors through the mapper and decoder; save reconstructions and originals.
site = 'm2s1'
train_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/train_img_' + site + '.npy')
val_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/val_img_' + site + '.npy')
train_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/trainRsp_' + site + '.npy')
val_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/valRsp_' + site + '.npy')
out_dir = 'recon_tang_m2s1_fakeReal'
os.makedirs(out_dir, exist_ok=True)  # imsave fails if the directory does not exist

prednet.load_state_dict(torch.load('recon_model_tang_fakeReal_m2s1'))
# An interrupted training loop leaves prednet in train mode, where BatchNorm normalises with
# batch statistics and updates its running stats. Switch to eval for inference.
prednet = prednet.to(device).eval()

# NOTE(review): these samples come from the *training* split (train_y / train_x). Use
# val_y / val_x to inspect held-out reconstructions.
sample = torch.tensor(train_y[:20], dtype=torch.float).to(device)
with torch.no_grad():
    recon = prednet(network(sample)).cpu().numpy()
origin = train_x[:20]
for i, (r_img, img) in enumerate(zip(recon, origin)):
    plt.imsave(os.path.join(out_dir, f'recon_{i}.png'), np.reshape(r_img, (50, 50)), cmap='gray')
    plt.imsave(os.path.join(out_dir, f'origin_{i}.png'), np.reshape(img, (50, 50)), cmap='gray')

newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg
newimg


Tang data reconstruction from sparse-coding responses

In [9]:
# Sparse-coding responses of 299 cells ('best_rsp_' per cell), used as decoder inputs.
# NOTE: allow_pickle=True can execute arbitrary code while loading; only open trusted files.
sparse_coding_dict = np.load("all_cell_dict_.npy", allow_pickle=True)[()]
sparse_coding_value = np.stack(
    [sparse_coding_dict[cell]['best_rsp_'] for cell in range(299)]
).T

# NOTE(review): relies on `val_x` from an earlier cell (the most recently loaded site).
# The first 900 of those images train the decoder; the remainder validate it.
train_x_new, val_x_new = val_x[:900], val_x[900:]
train_y_new, val_y_new = sparse_coding_value[:900], sparse_coding_value[900:]
train_loader = array_to_dataloader(train_x_new, train_y_new, batch_size=10, shuffle=True)
val_loader = array_to_dataloader(val_x_new, val_y_new, batch_size=10)

In [None]:
# Train a decoder that reconstructs Tang images from sparse-coding responses.
# The decoder's input width comes from the response array instead of a hard-coded 299.
prednet = reconstruct_CNN(train_y_new.shape[-1]).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)
criterion = nn.functional.mse_loss
losses = []
accs = []  # validation MSE per epoch (not an accuracy)
bestloss = float('inf')
num_epochs = 100
for e in tqdm(range(num_epochs)):
    train_losses = []
    prednet = prednet.train()
    for x, y in train_loader:
        x = x.float().to(device)
        y = y.float().to(device)
        recon = prednet(y)
        loss = criterion(recon, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    val_losses = []
    prednet = prednet.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x = x.float().to(device)
            y = y.float().to(device)
            recon = prednet(y)
            loss = criterion(recon, x)
            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "sparse_coding_recon_model_tang_data")
        bestloss = avg_loss

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

In [None]:
# Decode the first 10 held-out sparse-coding response vectors and save them next to the originals.
out_dir = 'recon_sparse'
os.makedirs(out_dir, exist_ok=True)  # imsave fails if the directory does not exist
prednet.load_state_dict(torch.load('sparse_coding_recon_model_tang_data'))
# BatchNorm must use its running statistics for inference.
prednet = prednet.to(device).eval()
sample = torch.tensor(val_y_new[:10], dtype=torch.float).to(device)
with torch.no_grad():
    recon = prednet(sample).cpu().numpy()
origin = val_x_new[:10]
for i, (r_img, img) in enumerate(zip(recon, origin)):
    plt.imsave(os.path.join(out_dir, f'recon_{i}.png'), np.reshape(r_img, (50, 50)), cmap='gray')
    plt.imsave(os.path.join(out_dir, f'origin_{i}.png'), np.reshape(img, (50, 50)), cmap='gray')

CIFAR data reconstruction from sparse-coding template convolutions

In [62]:
# Sparse-coding basis stored as (pixels, filters). Reshape it to conv2d weights (filters, 1, k, k).
templates = np.load("Bruno_BASIS1_NUM_512_size16.npy")
templates = np.transpose(templates)
filter_num, n_pixels = templates.shape
filter_size = int(round(np.sqrt(n_pixels)))
assert filter_size ** 2 == n_pixels, f"basis vectors of length {n_pixels} are not square patches"
templates = np.reshape(templates, (filter_num, 1, filter_size, filter_size))
# conv2d requires weights with the same dtype as the float32 images; a float64 .npy would
# otherwise produce a float64 tensor and fail at the first convolution.
templates = torch.tensor(templates, dtype=torch.float32).to(device)

In [55]:
def process_img_batch(imgs, filters):
    """Project the central patch of each image onto every sparse-coding template.

    Each template is applied as a single 'valid' convolution to the centred patch
    with the template's own size, so every template yields one number per image.

    Args:
        imgs: image batch of shape (B, C, H, W). The patch is centred using H and W.
        filters: conv2d weights of shape (F, C, kh, kw) with kh <= H and kw <= W.

    Returns:
        Tensor of shape (B, F) holding the template responses.
    """
    kh, kw = filters.shape[-2:]
    # Centre the window from the actual image size. For 50 x 50 images this equals the former
    # hard-coded (50 - s) // 2 offset.
    top = (imgs.shape[-2] - kh) // 2
    left = (imgs.shape[-1] - kw) // 2
    image_center = imgs[:, :, top:top + kh, left:left + kw]
    sparse_rsp = conv2d(image_center, filters)  # (B, F, 1, 1)
    # Use the filters' own count rather than the notebook-global `filter_num`, which can be stale.
    return torch.reshape(sparse_rsp, (len(sparse_rsp), filters.shape[0]))

In [None]:
# Train a decoder on CIFAR from fixed sparse-coding template responses (one response per template).
# The loss was previously assigned to the misspelled name `iterion`, so the loop silently used
# whatever `criterion` an earlier cell had left behind.
criterion = nn.functional.mse_loss

prednet = reconstruct_CNN(templates.shape[0]).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)
losses = []
accs = []  # validation MSE per epoch (not an accuracy)
bestloss = float('inf')
num_epochs = 100
for e in tqdm(range(num_epochs)):
    train_losses = []
    prednet = prednet.train()
    for x, _ in trainloader:
        x = x.float().to(device)
        rsp = process_img_batch(x, templates)
        recon = prednet(rsp)
        loss = criterion(recon, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    val_losses = []
    prednet = prednet.eval()
    with torch.no_grad():
        for x, _ in testloader:
            x = x.float().to(device)
            rsp = process_img_batch(x, templates)
            recon = prednet(rsp)
            loss = criterion(recon, x)
            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "filter_recon_model_16")
        bestloss = avg_loss

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

In [64]:
# Save reconstructions from the template-convolution decoder for the first 100 CIFAR test images.
out_dir = 'recon_sparse_convolve_16'
os.makedirs(out_dir, exist_ok=True)  # imsave fails if the directory does not exist
prednet.load_state_dict(torch.load('filter_recon_model_16'))
# Inference must run in eval mode (previously train()). In train mode BatchNorm normalises
# with per-batch statistics (here a single image) and updates the running stats just loaded.
prednet.eval()
with torch.no_grad():
    for i in range(100):
        sample, _ = testset[i]
        sample = torch.reshape(sample.to(device), (1, 1, 50, 50))
        recon = prednet(process_img_batch(sample, templates)).cpu().numpy()
        origin = sample.cpu().numpy()

        plt.imsave(os.path.join(out_dir, f'recon_{i}.png'), np.reshape(recon, (50, 50)), cmap='gray')
        plt.imsave(os.path.join(out_dir, f'origin_{i}.png'), np.reshape(origin, (50, 50)), cmap='gray')

CIFAR data reconstruction with CNN learning the filters

In [None]:
# Sparse-coding basis stored as (pixels, filters). Reshape it to conv2d weights (filters, 1, k, k).
templates = np.load("Bruno_BASIS1_NUM_512_size16.npy")
templates = np.transpose(templates)
filter_num, n_pixels = templates.shape
filter_size = int(round(np.sqrt(n_pixels)))
assert filter_size ** 2 == n_pixels, f"basis vectors of length {n_pixels} are not square patches"
templates = np.reshape(templates, (filter_num, 1, filter_size, filter_size))
# conv2d requires weights with the same dtype as the float32 images; a float64 .npy would
# otherwise produce a float64 tensor and fail at the first convolution.
templates = torch.tensor(templates, dtype=torch.float32).to(device)


def process_img_batch(imgs, filters):
    """Project the central patch of each image onto every sparse-coding template.

    NOTE(review): this redefines the function of the same name from the template-convolution
    section; keep the two in sync or remove one.

    Args:
        imgs: image batch of shape (B, C, H, W). The patch is centred using H and W.
        filters: conv2d weights of shape (F, C, kh, kw) with kh <= H and kw <= W.

    Returns:
        Tensor of shape (B, F) holding the template responses.
    """
    kh, kw = filters.shape[-2:]
    # Centre the window from the actual image size. For 50 x 50 images this equals the former
    # hard-coded (50 - s) // 2 offset.
    top = (imgs.shape[-2] - kh) // 2
    left = (imgs.shape[-1] - kw) // 2
    image_center = imgs[:, :, top:top + kh, left:left + kw]
    sparse_rsp = conv2d(image_center, filters)  # (B, F, 1, 1)
    # Use the filters' own count rather than the notebook-global `filter_num`, which can be stale.
    return torch.reshape(sparse_rsp, (len(sparse_rsp), filters.shape[0]))