In [1]:
import os
import torch
import numpy as np
# from torchsummary import summary ONLY works on pip!

# CODE FILES HERE
from model_params import get_model_data_dcgan
from models.dcgan import Dcgan, Generator, Discriminator
from models.resnet import ResNet, ResGenNet, ResDiscNet

from solver import Solver, Testing
from directories import Directories
from dataloader import DataLoader
from plots import plot_losses, plot_z_samples, plot_grid
from sampling import dcgan_sampling, generate_latent_points, interpolate_points, find_closest_gt
from contact_maps import get_contact_maps
import preprocessing

# SETTINGS HERE
# Must be set before any CUDA work is done in this kernel to have an effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # to see the CUDA stack
# Show matplotlib figures directly in the cell output.
%matplotlib inline
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
# suppress cluttering warnings in solutions
# NOTE(review): 'ignore' without a category silences every warning, including
# deprecation notices -- consider narrowing the filter while debugging.
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Select the compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using cuda
if device.type == 'cuda':
    # torch.cuda.memory_cached() is deprecated in favour of memory_reserved();
    # fall back to the old name only on PyTorch versions that predate the rename.
    reserved_bytes = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    gib = 1024 ** 3  # bytes per GiB, used to scale the raw byte counts
    print(torch.cuda.get_device_name(device))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(device) / gib, 1), 'GB')
    print('Cached:   ', round(reserved_bytes(device) / gib, 1), 'GB')

Using device: cpu



In [None]:
# Turn the raw protein files into preprocessed HDF5 files.
dataset_arg = "proteins"
data = get_model_data_dcgan(dataset_arg)

# Fragment length used throughout the notebook; the preprocessing module's
# maximum sequence length is set to the same value.
residue_fragments = 256
preprocessing.MAX_SEQUENCE_LENGTH = residue_fragments

data_root_folder = "../data/proteins/"
prefix = "_max_length/"
preprocessing.process_raw_data(
    False,
    force_pre_processing_overwrite=False,
    prefix=prefix,
    data_root_folder=data_root_folder,
)

In [None]:
# Path of the training file (relative to the data root) and the padding mode to use.
train_file_name = "training_50"
training_file = "{}{}.txt.hdf5".format(data_root_folder, train_file_name)
print("training file: {0}".format(training_file))

padding = "pwd_pad"
print("padding: {}".format(padding))

In [None]:
# Directory helper for the "dcgan" model; make_dirs=False presumably prevents it from
# creating result folders -- TODO confirm in directories.py.
directories = Directories(
    "dcgan",
    dataset_arg.lower(),
    data["z_dim"],
    make_dirs=False,
)

# Loader over the training file, configured with the fragment length and padding chosen above.
data_loader = DataLoader(
    directories,
    data["batch_size"],
    dataset_arg.lower(),
    training_file=training_file,
    residue_fragments=residue_fragments,
    atom="calpha",
    padding=padding,
)

In [None]:
# Plot the first 25 ground-truth contact maps in a 5x5 grid.
samples = get_contact_maps(
    training_file,
    fragment_length=residue_fragments,
    padding=padding,
).unsqueeze(1)  # insert a singleton dimension at position 1
print(samples.shape)
samples = samples[:25]
plot_grid(samples, None, nrow=5, ncol=5)

In [None]:
# DCGAN alternative (currently disabled): uncomment the lines below to use the DCGAN instead of the ResNet model
#model = Dcgan(data_loader.input_dim, data["z_dim"])
#generator = Generator(data["z_dim"], res=residue_fragments)
#discriminator = Discriminator(1, 1, res=residue_fragments)

In [None]:
# ResNet-based GAN; only for 256 fragments!
num_layers = 5  # shared by the generator and the discriminator

model = ResNet(data_loader.input_dim, data["z_dim"])
generator = ResGenNet(data["z_dim"], num_layers)
discriminator = ResDiscNet(num_layers)

In [None]:
# summary(generator, (100, 1, 1))
# summary(discriminator, (1, 256, 256))

In [None]:
# Assemble the solver from the model, both networks and the settings in `data`,
# then run its main routine.
solver = Solver(
    model,
    generator,
    discriminator,
    data["epochs"],
    data_loader,
    data["optimizer_G"],
    data["optimizer_D"],
    data["optim_config_G"],
    data["optim_config_D"],
    preprocessing.MAX_SEQUENCE_LENGTH,
    data["one_sided_labeling"],
    data["g_updates"],
    save_model_state=False,
)
solver.main()

In [3]:
# Either restore a previously saved solver from disk, or keep working with the
# `solver` created earlier in this notebook (set LOAD_MODEL = False for the latter).
LOAD_MODEL = True
if LOAD_MODEL:
    res_dir = "../results/"
    # Insert the name of the model here: path of the saved solver without its ".pt" extension.
    load_file = res_dir + "2000_64_17_850"
    # SECURITY: this unpickles a complete Solver object, and unpickling can execute
    # arbitrary code -- only load files from a trusted source.
    # NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
    # which rejects pickled objects like this one; confirm the installed version.
    solver = torch.load(load_file + ".pt", map_location="cpu")
    generator = solver.generator
    discriminator = solver.discriminator
    solver.data_loader.directories.make_dirs = False
else:
    res_dir = solver.data_loader.directories.result_dir
    # Later cells derive file names from `load_file`; define it in this branch too so
    # they do not fail with a NameError. The stem itself is an arbitrary placeholder.
    load_file = os.path.join(res_dir, "trained_model")

In [None]:
# test complexity of model as in A.4
test_file_name = "test_proteins"
prefix = "_max_length/"
data_root_folder = "../data/proteins/"

# PDB truthy  -> use the ".pt" test file directly under preprocessed/
# PDB falsy   -> use the ".txt.hdf5" file under preprocessed/<MAX_SEQUENCE_LENGTH>_max_length/
PDB = 1
preprocessed_dir = data_root_folder + "preprocessed/"
if PDB:
    testing_file = preprocessed_dir + test_file_name + ".pt"
else:
    testing_file = (preprocessed_dir + str(preprocessing.MAX_SEQUENCE_LENGTH)
                    + prefix + test_file_name + ".txt.hdf5")

# Report the selected file for both settings of PDB (this used to run only in one branch).
print("testing file: {0}".format(testing_file))

# Adam settings for the optimizer handed to Testing together with the generator.
optim_config_G = {
    "lr": 1e-2,
    "weight_decay": 0.0,
    "betas": (0.5, 0.999)
}
optimizer_G = torch.optim.Adam(generator.parameters(), **optim_config_G)

# NOTE(review): test_pdb is always True here, even when PDB is falsy -- confirm whether
# it should follow the PDB flag.
test_loader = solver.data_loader.get_new_test_data_loader(
    testing_file=testing_file,
    batch_size=1,
    padding="no_pad",
    test_pdb=True,
)
testing = Testing(solver.generator, solver.model.z_dim, optimizer_G, test_loader,
                  load_file+"_test_generator_dict.pt")
testing.test()

testing file: ../data/proteins/preprocessed/test_proteins.pt
Reading cache file in ../data/proteins/preprocessed/test_proteins_64_no_pad_contact_maps.dat
Testing complexity of the GAN, steps: 3000


In [None]:
# load inference generator
# (disabled) rebuild the generator and restore the state dict saved next to `load_file`:
# generator = Generator(solver.model.z_dim, res=solver.data_loader.residue_fragments)
# generator.load_state_dict(torch.load(load_file+"_test_generator_dict.pt"))
# Put both networks into evaluation mode before sampling from them.
generator.eval()
discriminator.eval()

In [None]:
# Plot the recorded generator ("g_loss") and discriminator ("d_loss") losses.
loss_history = solver.train_loss_history
plot_losses(solver, loss_history["g_loss"], loss_history["d_loss"])

In [None]:
# sampling a grid of distance maps
# .cpu() keeps the NumPy conversion working when the generator's output lives on a GPU;
# it is a no-op for tensors that are already on the CPU.
samples = dcgan_sampling(generator, solver.model.z_dim, 25).detach().cpu().numpy()
imgs, rows, cols = solver.get_sample_stats()
# os.path.join avoids a doubled separator when res_dir already ends with "/".
plot_grid(samples[:imgs], os.path.join(res_dir, "plot_grid.png"), nrow=rows, ncol=cols)

In [None]:
# Drop the names used for the sample grid.
del imgs, samples

In [None]:
# generate two points in latent space
pts = generate_latent_points(solver.model.z_dim, 2)
# interpolate between the two points in latent space
zs = interpolate_points(pts[0], pts[1])
# .cpu() keeps the NumPy conversion working when the generator's output lives on a GPU;
# it is a no-op for tensors that are already on the CPU.
samples = generator(zs).detach().cpu().numpy()
imgs, rows, cols = solver.get_sample_stats()
# os.path.join avoids a doubled separator when res_dir already ends with "/".
plot_grid(samples[:imgs], os.path.join(res_dir, "plot_linear_interpolation.png"), 1, 8)

In [None]:
# Drop the names used for the latent-space interpolation.
del pts, zs, samples

In [None]:
# Ground-truth data set to compare the generated maps against.
max_len = 256
train_file_name = "training_95"
training_file = "".join([
    "../data/proteins/",
    "preprocessed/",
    str(max_len),
    "_max_length/",
    train_file_name,
    ".txt.hdf5",
])
print(training_file)

# Loader that yields one ground-truth entry per batch.
test_loader = solver.data_loader.get_new_test_data_loader(
    testing_file=training_file,
    batch_size=1,
)

# find closest ground truth
min_loss_maps = find_closest_gt(generator, solver.model.z_dim, test_loader)

In [None]:
# Compare generated maps with their closest ground-truth maps.
# Each entry of min_loss_maps is indexed as: [1] loss, [2] generated map, [3] ground-truth map.
imgs, rows, cols = solver.get_sample_stats()
gens = torch.FloatTensor([x[2] for x in min_loss_maps[:imgs]])
gt_maps = torch.cat([x[3] for x in min_loss_maps[:imgs]]).unsqueeze(1)
losses = [x[1] for x in min_loss_maps[:imgs]]
print("Losses: {}".format(losses))

# Total and mean loss over *all* entries, not only the plotted ones.
l2_losses = [x[1] for x in min_loss_maps]
l2_total = sum(l2_losses)
print("L2 map error {} {}".format(l2_total, l2_total/len(l2_losses)))

# Generated maps first, followed by the ground-truth maps.
comparison = torch.cat([gens, gt_maps])
# The total must be formatted into the file name: concatenating a str with a number
# raises TypeError. float() also turns a one-element tensor into a plain number.
plot_name = "_plot_{}_comparison.png".format(float(l2_total))
plot_grid(comparison, os.path.join(res_dir, plot_name))

In [None]:
# Drop the names used for the ground-truth comparison.
del test_loader, min_loss_maps, gens, gt_maps