In [11]:
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from torch.optim import Adam
from analysis.Stat_utils import get_rsp_data
from modeling.models.bethge import BethgeModel
from modeling.losses import corr_loss

In [12]:
class Mapper(torch.nn.Module):
    """Fully connected network mapping a vector of length ``numNeurons`` to one of the same length.

    The stack is ``numNeurons -> 512 -> 1024 -> 512 -> numNeurons`` with a
    LeakyReLU after every linear layer except the last one.

    Args:
        numNeurons: size of both the input and the output feature dimension.
    """

    def __init__(self, numNeurons):
        super().__init__()
        widths = [numNeurons, 512, 1024, 512, numNeurons]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            if stack:
                # Non-linearity sits between consecutive linear layers only.
                stack.append(nn.LeakyReLU())
            stack.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)

In [13]:
# --- Hyperparameters forwarded to BethgeModel below ---
channels = 256
num_layers = 9
input_size = 50
output_size = 299
first_k = 9
later_k = 3
pool_size = 2
factorized = True
num_maps = 1

# Identifier shared by the checkpoint filename and the get_rsp_data call.
site = 'm2s1'
# NOTE(review): machine-specific absolute path; consider moving it to a
# configurable location so the notebook runs on other machines.
model_dir = 'D:/school/research/CNN_Tang_project/saved_models/new_learned_models/'

# Build the network from the variables declared above (previously input_size
# and factorized were re-typed as literals and could silently drift).
net = BethgeModel(channels=channels, num_layers=num_layers, input_size=input_size,
                  output_size=output_size, first_k=first_k, later_k=later_k,
                  input_channels=1, pool_size=pool_size, factorized=factorized,
                  num_maps=num_maps).cuda()

# map_location keeps the load independent of the device the checkpoint was
# saved from (e.g. a different GPU index).
net.load_state_dict(torch.load(model_dir + site + '_9_model_version_0', map_location='cuda'))

# get_rsp_data returns two values that are unpacked as (fake_rsp, real_rsp).
# NOTE(review): net is not switched to eval() here — confirm get_rsp_data does
# so if the model contains dropout / batch-norm layers.
fake_rsp, real_rsp = get_rsp_data(net, 'cuda', site)

100%|██████████| 245/245 [00:33<00:00,  7.29it/s]


In [14]:
# Wrap both response arrays as tensors; each dataset item is a
# (real_rsp row, fake_rsp row) pair, served in shuffled batches of 256.
real_tensor = torch.tensor(real_rsp)
fake_tensor = torch.tensor(fake_rsp)
dataset = TensorDataset(real_tensor, fake_tensor)
loader = DataLoader(dataset, batch_size=256, shuffle=True)

In [16]:
# Train the Mapper on (x, y) batches from `loader` and keep the checkpoint
# with the lowest mean epoch loss.
num_epochs = 1000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size the mapper from the data instead of a hard-coded neuron count;
# nn.Linear operates on the last dimension of its input.
num_neurons = dataset.tensors[0].shape[-1]
mapper = Mapper(num_neurons).to(device)

optimizer = Adam(mapper.parameters(), lr=1e-3)

# Start from +inf so the first epoch always produces a checkpoint.
best_loss = float('inf')
progress = tqdm(range(num_epochs))
for epoch in progress:
    total_loss = 0
    for x, y in loader:
        optimizer.zero_grad()
        x = x.to(device)
        y = y.to(device)
        out = mapper(x)
        loss = corr_loss(out, y, corr_portion=0.9, mae_portion=0.1)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean loss over the batches of this epoch.
    total_loss /= len(loader)
    progress.set_postfix(loss=total_loss, best=best_loss)
    if total_loss < best_loss:
        # Record the new best; without this update every epoch would overwrite
        # the file and the saved weights would be the last epoch, not the best.
        best_loss = total_loss
        torch.save(mapper.state_dict(), 'real_fake_mapper_corr')

100%|██████████| 1000/1000 [12:21<00:00,  1.35it/s]
