In [1]:
import torchvision
import torchvision.transforms as transforms
import torch
from modeling.models.bethge import BethgeModel
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from modeling.train_utils import array_to_dataloader
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import conv2d
import matplotlib.pyplot as plt
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# TODO: fine-tune with validation set?

CIFAR tests using shared core

In [25]:

# Preprocessing applied to every CIFAR-10 image: convert to a tensor, resize
# to 50 px, normalise each RGB channel with mean 0.5 / std 0.5, and finally
# collapse the three channels into one grayscale channel.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize(50),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Grayscale(),
]
transform = transforms.Compose(preprocessing_steps)

batch_size = 100

# Both splits share the same storage root and preprocessing.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

# Only the training loader shuffles; the test loader keeps a fixed order.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
testloader = DataLoader(testset, shuffle=False, **loader_kwargs)
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)


Files already downloaded and verified
Files already downloaded and verified


In [2]:
# Architecture hyper-parameters of the pretrained encoder.
# NOTE(review): these presumably have to match the checkpoint loaded below — confirm.
channels = 256
num_layers = 9
input_size = 50

output_size = 299
first_k = 9
later_k = 3
pool_size = 2
factorized = True

num_maps = 1

# Pass the `factorized` variable (not a literal) so the setting above is the
# single source of truth, and move the model once to the configured device.
net = BethgeModel(channels=channels, num_layers=num_layers, input_size=input_size,
                  output_size=output_size, first_k=first_k, later_k=later_k,
                  input_channels=1, pool_size=pool_size, factorized=factorized,
                  num_maps=num_maps).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can also be restored on a CPU-only machine.
net.load_state_dict(torch.load('../saved_models/new_learned_models/m2s1_9_model_version_0',
                               map_location=device))
#net.load_state_dict(torch.load(f'../saved_models/cropped_models/m2s1_size_{input_size}_model'))

<All keys matched successfully>

In [6]:
class reconstruct_CNN(nn.Module):
    """Decoder that turns a response vector into a single-channel 50x50 image.

    A linear layer expands the input into a small seed feature map
    (``hidden_dims[0]`` channels, 2x2 pixels). Four stride-2 transposed
    convolutions double the spatial size each time (2 -> 4 -> 8 -> 16 -> 32),
    a dilated transposed convolution enlarges 32 -> 50, and a final 3x3
    convolution followed by ``tanh`` produces one channel in ``[-1, 1]``.

    Args:
        num_neuron: Length of each input response vector.
    """

    def __init__(self, num_neuron):
        super().__init__()

        hidden_dims = [16, 64, 128, 64, 16]
        # Shape of the seed feature map; stored so that forward() does not
        # have to repeat these numbers as magic constants.
        self.seed_channels = hidden_dims[0]
        self.seed_size = 2

        # Each stage doubles the spatial resolution:
        # (H - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1 == 2 * H
        modules = [
            nn.Sequential(
                nn.ConvTranspose2d(in_channels,
                                   out_channels,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU())
            for in_channels, out_channels in zip(hidden_dims[:-1], hidden_dims[1:])
        ]

        # The dilated transposed convolution maps 32 -> 50 pixels:
        # (32 - 1) * 1 - 2 * 2 + 5 * (5 - 1) + 2 + 1 == 50
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=5,
                               stride=1,
                               padding=2,
                               output_padding=2,
                               dilation=5),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=1,
                      kernel_size=3, padding=1),
            nn.Tanh())

        self.layers = nn.Sequential(*modules)
        self.linear_input = nn.Linear(num_neuron,
                                      self.seed_channels * self.seed_size ** 2)

    def forward(self, x):
        """Decode a batch of response vectors.

        Args:
            x: Tensor of shape ``(batch, num_neuron)``.

        Returns:
            Tensor of shape ``(batch, 1, 50, 50)`` with values in ``[-1, 1]``.
        """
        x = self.linear_input(x)
        x = x.view(-1, self.seed_channels, self.seed_size, self.seed_size)
        x = self.layers(x)
        x = self.final_layer(x)
        return x

In [4]:

# Reconstruction objective: mean squared error between decoded and input image.
criterion = nn.functional.mse_loss

# The encoder is moved to the device and switched to inference mode.
network = net.to(device).eval()

# Decoder for 299-dimensional response vectors; only its parameters are
# handed to the optimizer.
prednet = reconstruct_CNN(299).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)

# Per-epoch histories: `losses` holds training loss, `accs` holds validation loss.
losses, accs = [], []


In [29]:
# Train the decoder (`prednet`) to reconstruct CIFAR images from the responses
# of the encoder (`network`). The test loader doubles as the validation set
# used for checkpoint selection.
bestloss = float('inf')
num_epochs = 100
crop = transforms.CenterCrop(input_size)
for e in tqdm(range(num_epochs)):
    # ---- training pass ----
    train_losses = []
    prednet.train()
    for x, _ in trainloader:  # class labels are not needed for reconstruction
        x = x.float().to(device)
        # The optimizer only holds prednet's parameters, so the encoder is
        # evaluated without autograd: this avoids building its graph and
        # accumulating gradients on its parameters that are never used.
        with torch.no_grad():
            rsp = network(crop(x))
        recon = prednet(rsp)
        # NOTE(review): the target is the un-cropped image; this matches the
        # decoder output only while input_size equals the image size — confirm.
        loss = criterion(recon, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    # ---- validation pass ----
    val_losses = []
    prednet.eval()
    with torch.no_grad():
        for x, _ in testloader:
            x = x.float().to(device)
            rsp = network(crop(x))
            recon = prednet(rsp)
            loss = criterion(recon, x)

            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    # Keep only the checkpoint with the lowest validation loss so far.
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "sanity_check_m2s1_model")
        bestloss = avg_loss

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

  1%|          | 1/100 [01:55<3:09:51, 115.06s/it]

epoch 0 : train loss is 0.14625312180817127
epoch 0 : val loss is   0.11408415392041206


  2%|▏         | 2/100 [03:50<3:07:59, 115.09s/it]

epoch 1 : train loss is 0.1092441799044609
epoch 1 : val loss is   0.1111518020182848


  3%|▎         | 3/100 [05:45<3:06:16, 115.22s/it]

epoch 2 : train loss is 0.1014069500118494
epoch 2 : val loss is   0.10031491592526436


  4%|▍         | 4/100 [07:38<3:02:39, 114.16s/it]

epoch 3 : train loss is 0.09681678412854672
epoch 3 : val loss is   0.09602990813553333


  5%|▌         | 5/100 [09:30<2:59:29, 113.37s/it]

epoch 4 : train loss is 0.09399545885622501
epoch 4 : val loss is   0.09459649935364724


  6%|▌         | 6/100 [11:22<2:56:55, 112.93s/it]

epoch 5 : train loss is 0.09179045873880386
epoch 5 : val loss is   0.0910287506878376


  7%|▋         | 7/100 [13:14<2:54:42, 112.71s/it]

epoch 6 : train loss is 0.0901765878200531
epoch 6 : val loss is   0.09155401557683945


  8%|▊         | 8/100 [15:06<2:52:42, 112.64s/it]

epoch 7 : train loss is 0.08864386230707169
epoch 7 : val loss is   0.08789640225470066


  8%|▊         | 8/100 [16:55<3:14:42, 126.98s/it]


KeyboardInterrupt: 

In [85]:

# Evaluate the current decoder on the CIFAR test loader and print the mean of
# the per-batch reconstruction losses.
prednet.eval()
val_losses = []
with torch.no_grad():
    for x, _ in testloader:  # class labels are not needed for reconstruction
        x = x.float().to(device)
        rsp = network(x)
        recon = prednet(rsp)
        val_losses.append(criterion(recon, x).item())
avg_loss = np.mean(val_losses)
print(avg_loss)

0.07913384726271033


In [17]:
import matplotlib.pyplot as plt
# Restore decoder and encoder checkpoints, then save the first 100 test
# images next to their reconstructions as grayscale PNGs.
# NOTE(review): the checkpoint names mention size 35 while the samples below
# are reshaped to 50x50 — confirm that the loaded weights match `network`.
prednet.load_state_dict(torch.load('artificial_recon_model_cropped_m2s1_35',
                                   map_location=device))
network.load_state_dict(torch.load('../saved_models/cropped_models/m2s1_size_35_model',
                                   map_location=device))
network.eval()
prednet.eval()
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for i in range(100):
        sample, _ = testset[i]
        sample = torch.reshape(sample.to(device), (1, 1, 50, 50))
        recon = prednet(network(sample)).cpu().numpy()
        origin = sample.cpu().numpy()

        r_img = np.reshape(recon, (50, 50))
        img = np.reshape(origin, (50, 50))
        # imsave writes straight to disk without creating a figure.
        # The target directory must already exist.
        plt.imsave(f'recon_artificial_m2s1_35/recon_{i}.png', r_img, cmap='gray')
        plt.imsave(f'recon_artificial_m2s1_35/origin_{i}.png', img, cmap='gray')

Tang data reconstruction with shared core

In [11]:
# Recording site whose images and responses are loaded.
site = 'm3s1'

# Images are moved from axis order (0, 1, 2, 3) to (0, 3, 1, 2)
# (presumably channels-last -> channels-first; verify against the data files).
train_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/train_img_'+site+'.npy').transpose(0, 3, 1, 2)
val_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/val_img_'+site+'.npy').transpose(0, 3, 1, 2)

# Responses paired with the images above.
train_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/trainRsp_'+site+'.npy')
val_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/valRsp_'+site+'.npy')

# Only the training loader is shuffled.
train_loader = array_to_dataloader(train_x, train_y, batch_size=1024, shuffle=True)
val_loader = array_to_dataloader(val_x, val_y, batch_size=1024)

In [None]:
# Fresh decoder for 324-dimensional response vectors, trained to reconstruct
# the stimulus images directly from the loaded responses (the encoder network
# is not used in this cell).
prednet = reconstruct_CNN(324).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)
criterion = nn.MSELoss()
losses, accs = [], []
bestloss = 200
num_epochs = 100
for e in tqdm(range(num_epochs)):
    # ---- training pass ----
    prednet.train()
    train_losses = []
    for images, responses in train_loader:
        images = images.float().to(device)
        responses = responses.float().to(device)
        loss = criterion(prednet(responses), images)
        # Gradients are cleared after the step, so every backward() starts clean.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    # ---- validation pass ----
    prednet.eval()
    val_losses = []
    with torch.no_grad():
        for images, responses in val_loader:
            images = images.float().to(device)
            responses = responses.float().to(device)
            val_losses.append(criterion(prednet(responses), images).item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)

    # Best-so-far checkpoint plus an unconditional "latest epoch" checkpoint.
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "real_recon_model")
        bestloss = avg_loss
    torch.save(prednet.state_dict(), "real_recon_model_accumulative")

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

  1%|          | 1/100 [00:07<11:44,  7.12s/it]

epoch 0 : train loss is 0.0562766996877534
epoch 0 : val loss is   0.03410699591040611


In [None]:
# Reload the data of site m2s1 and save reconstructions of the first ten
# validation images produced by encoder + decoder.
site = 'm2s1'
train_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/train_img_'+site+'.npy')
val_x = np.load('../data/Processed_Tang_data/all_sites_data_prepared/pics_data/val_img_'+site+'.npy')
train_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/trainRsp_'+site+'.npy')
val_y = np.load('../data/Processed_Tang_data/all_sites_data_prepared/New_response_data/valRsp_'+site+'.npy')
sample = torch.tensor(val_x[:10], dtype=torch.float).to(device)
sample = torch.reshape(sample, (sample.shape[0], 1, 50, 50))
# NOTE(review): in top-to-bottom order `prednet` was last built with 324 inputs,
# while `network` is configured with output_size=299 — confirm that `prednet`
# matches this checkpoint and the encoder output before running.
prednet.load_state_dict(torch.load('artificial_recon_model_tang_data', map_location=device))
prednet = prednet.to(device)
# Inference mode: BatchNorm must use its running statistics, and no autograd
# graph is needed.
prednet.eval()
network.eval()
with torch.no_grad():
    recon = prednet(network(sample)).cpu().numpy()
origin = val_x[:10]
for i, (r_img, img) in enumerate(zip(recon, origin)):
    r_img = np.reshape(r_img, (50, 50))
    img = np.reshape(img, (50, 50))
    print("newimg")
    # imsave writes straight to disk without creating a figure.
    # The target directory must already exist.
    plt.imsave(f'recon_tang/recon_{i}.png', r_img, cmap='gray')
    plt.imsave(f'recon_tang/origin_{i}.png', img, cmap='gray')

Tang data reconstruction with sparse coding responses

In [9]:
# SECURITY NOTE: allow_pickle=True unpickles arbitrary Python objects; only
# load this file when it comes from a trusted source.
# The file stores a pickled dict inside a 0-d object array; .item() unwraps it.
sparse_coding_dict = np.load("all_cell_dict_.npy", allow_pickle=True).item()

# One 'best_rsp_' entry per cell index; after the transpose the cell axis is last.
num_cells = 299
sparse_coding_value = np.stack(
    [sparse_coding_dict[cell]['best_rsp_'] for cell in range(num_cells)]
).T

# Split at index 900: earlier samples train the decoder, the rest validate it.
# `val_x` must already be in memory from an earlier cell.
n_train = 900
train_x_new = val_x[:n_train]
val_x_new = val_x[n_train:]
train_y_new = sparse_coding_value[:n_train]
val_y_new = sparse_coding_value[n_train:]
train_loader = array_to_dataloader(train_x_new, train_y_new, batch_size=10, shuffle=True)
val_loader = array_to_dataloader(val_x_new, val_y_new, batch_size=10)

In [None]:
# Fresh decoder for 299-dimensional sparse-coding responses, trained to
# reconstruct the paired images; the best validation checkpoint is kept.
prednet = reconstruct_CNN(299).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)
criterion = nn.functional.mse_loss
losses, accs = [], []
bestloss = 200
num_epochs = 100
for e in tqdm(range(num_epochs)):
    # ---- training pass ----
    prednet.train()
    train_losses = []
    for images, responses in train_loader:
        images = images.float().to(device)
        responses = responses.float().to(device)
        loss = criterion(prednet(responses), images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    # ---- validation pass ----
    prednet.eval()
    val_losses = []
    with torch.no_grad():
        for images, responses in val_loader:
            images = images.float().to(device)
            responses = responses.float().to(device)
            val_losses.append(criterion(prednet(responses), images).item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)

    # Overwrite the checkpoint whenever the validation loss improves.
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "sparse_coding_recon_model_tang_data")
        bestloss = avg_loss

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

In [None]:
# Decode the first ten held-out sparse-coding response vectors with the best
# saved decoder and store reconstruction / original pairs as grayscale PNGs.
site = 'm2s1'
sample = torch.tensor(val_y_new[:10], dtype=torch.float).to(device)
prednet.load_state_dict(torch.load('sparse_coding_recon_model_tang_data', map_location=device))
prednet = prednet.to(device)
# Inference mode: BatchNorm must use its running statistics, and no autograd
# graph is needed.
prednet.eval()
with torch.no_grad():
    recon = prednet(sample).cpu().numpy()
origin = val_x_new[:10]
for i, (r_img, img) in enumerate(zip(recon, origin)):
    r_img = np.reshape(r_img, (50, 50))
    img = np.reshape(img, (50, 50))
    print("newimg")
    # imsave writes straight to disk without creating a figure.
    # The target directory must already exist.
    plt.imsave(f'recon_sparse/recon_{i}.png', r_img, cmap='gray')
    plt.imsave(f'recon_sparse/origin_{i}.png', img, cmap='gray')

CIFAR data reconstruction with sparse coding template convolution

In [62]:
# Load the filter bank and reshape it into conv2d weight layout
# (num_filters, in_channels=1, height, width).
# NOTE(review): assumes each column of the file is one flattened square
# filter — confirm against the file that produced it.
templates = np.load("Bruno_BASIS1_NUM_512_size16.npy")
templates = np.transpose(templates)
filter_size = int(np.round(np.sqrt(templates.shape[1])))
filter_num = templates.shape[0]
templates = np.reshape(templates, (filter_num, 1, filter_size, filter_size))
# conv2d needs weights and inputs of the same dtype; the images are cast with
# .float() elsewhere, so the filters are pinned to float32 as well.
templates = torch.tensor(templates, dtype=torch.float32).to(device)

In [55]:
def process_img_batch(imgs, filters):
    """Apply every filter once to the centre patch of each image.

    The patch has the same spatial size as the filters, so the convolution
    yields exactly one value per (image, filter) pair.

    Args:
        imgs: Tensor of shape ``(batch, channels, height, width)``.
        filters: Tensor of shape ``(num_filters, channels, kernel_h, kernel_w)``.

    Returns:
        Tensor of shape ``(batch, num_filters)``.

    Raises:
        ValueError: If the images are smaller than the filters.
    """
    # All sizes come from the arguments themselves, so the function does not
    # depend on a hard-coded image size or on a global filter count.
    num_filters, _, kernel_h, kernel_w = filters.shape
    top = (imgs.shape[2] - kernel_h) // 2
    left = (imgs.shape[3] - kernel_w) // 2
    if top < 0 or left < 0:
        raise ValueError("images must be at least as large as the filters")
    image_center = imgs[:, :, top:top + kernel_h, left:left + kernel_w]
    sparse_rsp = conv2d(image_center, filters)
    return torch.reshape(sparse_rsp, (len(sparse_rsp), num_filters))

In [None]:
# Reconstruction objective. The name must be `criterion`, because that is
# what the loops below call.
criterion = nn.functional.mse_loss

# Decoder with one input per filter; it learns to reconstruct CIFAR images
# from the fixed filter-bank responses of their centre patch.
prednet = reconstruct_CNN(templates.shape[0]).to(device)
optimizer = torch.optim.Adam(prednet.parameters(), lr=0.005)
losses = []
accs = []
bestloss = float('inf')
num_epochs = 100
for e in tqdm(range(num_epochs)):
    # ---- training pass ----
    train_losses = []
    prednet.train()
    for x, _ in trainloader:  # class labels are not needed for reconstruction
        x = x.float().to(device)
        rsp = process_img_batch(x, templates)
        recon = prednet(rsp)
        loss = criterion(recon, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    losses.append(np.mean(train_losses))

    # ---- validation pass ----
    val_losses = []
    prednet.eval()
    with torch.no_grad():
        for x, _ in testloader:
            x = x.float().to(device)
            rsp = process_img_batch(x, templates)
            recon = prednet(rsp)
            loss = criterion(recon, x)

            val_losses.append(loss.item())
    avg_loss = np.mean(val_losses)
    accs.append(avg_loss)
    # Keep only the checkpoint with the lowest validation loss so far.
    if avg_loss < bestloss:
        torch.save(prednet.state_dict(), "filter_recon_model_16")
        bestloss = avg_loss

    print(f'epoch {e} : train loss is {float(losses[-1])}')
    print(f'epoch {e} : val loss is   {float(accs[-1])}')

In [64]:
import matplotlib.pyplot as plt
# Restore the best filter-response decoder and save the first 100 test images
# next to their reconstructions as grayscale PNGs.
prednet.load_state_dict(torch.load('filter_recon_model_16', map_location=device))
# Inference must run in eval mode: in train mode BatchNorm would normalise
# with the statistics of each single-image batch and update its running stats.
prednet.eval()
with torch.no_grad():
    for i in range(100):
        sample, _ = testset[i]
        sample = torch.reshape(sample.to(device), (1, 1, 50, 50))
        recon = prednet(process_img_batch(sample, templates)).cpu().numpy()
        origin = sample.cpu().numpy()

        r_img = np.reshape(recon, (50, 50))
        img = np.reshape(origin, (50, 50))
        # imsave writes straight to disk without creating a figure.
        # The target directory must already exist.
        plt.imsave(f'recon_sparse_convolve_16/recon_{i}.png', r_img, cmap='gray')
        plt.imsave(f'recon_sparse_convolve_16/origin_{i}.png', img, cmap='gray')

CIFAR data reconstruction with CNN learning the filters

In [None]:
# Load the filter bank and reshape it into conv2d weight layout
# (num_filters, in_channels=1, height, width).
# NOTE(review): assumes each column of the file is one flattened square
# filter — confirm against the file that produced it.
templates = np.load("Bruno_BASIS1_NUM_512_size16.npy")
templates = np.transpose(templates)
filter_size = int(np.round(np.sqrt(templates.shape[1])))
filter_num = templates.shape[0]
templates = np.reshape(templates, (filter_num, 1, filter_size, filter_size))
# conv2d needs weights and inputs of the same dtype; the images are cast with
# .float() elsewhere, so the filters are pinned to float32 as well.
templates = torch.tensor(templates, dtype=torch.float32).to(device)


def process_img_batch(imgs, filters):
    """Apply every filter once to the centre patch of each image.

    The patch has the same spatial size as the filters, so the convolution
    yields exactly one value per (image, filter) pair.

    Args:
        imgs: Tensor of shape ``(batch, channels, height, width)``.
        filters: Tensor of shape ``(num_filters, channels, kernel_h, kernel_w)``.

    Returns:
        Tensor of shape ``(batch, num_filters)``.

    Raises:
        ValueError: If the images are smaller than the filters.
    """
    # All sizes come from the arguments themselves, so the function does not
    # depend on a hard-coded image size or on a global filter count.
    num_filters, _, kernel_h, kernel_w = filters.shape
    top = (imgs.shape[2] - kernel_h) // 2
    left = (imgs.shape[3] - kernel_w) // 2
    if top < 0 or left < 0:
        raise ValueError("images must be at least as large as the filters")
    image_center = imgs[:, :, top:top + kernel_h, left:left + kernel_w]
    sparse_rsp = conv2d(image_center, filters)
    return torch.reshape(sparse_rsp, (len(sparse_rsp), num_filters))