In [2]:
from argparse import ArgumentParser
import numpy as np
import pandas as pd
import torch
import os
import matplotlib.pyplot as plt
from Datamodule  import DataModule
from pytorch_lightning import Trainer
from load_model import load_model

# --- Run switches and I/O locations ---
predict, visualize, export = True, True, False
num_workers = 8
results_dir, datapath = 'Results', 'data'

# --- Data parameters ---
# Resolutions are 'HxW' strings; energies are strings ('all' or a single value).
train_res, train_energy = '128x128', 'all'
test_res, test_energy = '128x128', '193'
max_samples = -1   # presumably -1 means "no cap on samples" — confirm in DataModule
batch_size = 4
cached = False

# --- Training hyper-parameters (passed to load_model) ---
criterion, lr, amsgrad = 'sq_err', 1e-3, True

# --- Model parameters ---
# NOTE: the evaluation cell rebinds `model` from this architecture name to the
# loaded network, so re-run this cell before re-running that one.
model = 'MLP'
n_layers, h_dim, k_size = 4, 32, 3
pc_err = '1.80e-01'  # kept as a string — presumably matched against a saved-model name; confirm
saved = True

def visualize_target_output(pred, y):
    """Show the prediction next to the target on a shared colour scale.

    Both inputs are averaged over their first dimension (``.mean(dim=0)``)
    before plotting, so e.g. a 3-D tensor becomes a 2-D image.
    Assumes the first dimension is a channel-like axis — TODO confirm against
    the model output layout.

    Args:
        pred: prediction tensor.
        y: target tensor, assumed to have the same shape as ``pred``.
    """
    # Detach and move to CPU so matplotlib can convert to NumPy even when the
    # tensors still track gradients or live on an accelerator.
    pred = pred.detach().cpu().mean(dim=0)
    y = y.detach().cpu().mean(dim=0)

    # Shared colour limits so both panels are directly comparable.
    vmin = min(pred.min().item(), y.min().item())
    vmax = max(pred.max().item(), y.max().item())

    fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True, sharex=True)
    axs[0].imshow(pred.numpy(), vmin=vmin, vmax=vmax, cmap='bone')
    # Both images share vmin/vmax, so either one can drive the colorbar.
    im = axs[1].imshow(y.numpy(), vmin=vmin, vmax=vmax, cmap='bone')

    axs[0].set_title('Prediction')
    axs[1].set_title('Target')
    fig.tight_layout()

    # Reserve a strip on the right for a single shared colorbar.
    fig.subplots_adjust(right=0.85)
    cbar_ax = fig.add_axes([0.88, 0.15, 0.04, 0.7])
    fig.colorbar(im, cax=cbar_ax)
    plt.show()


In [3]:
# Evaluate the saved model on the test split and, when `predict` is set,
# run the predict loop over the same test data.
trainer = Trainer(
    logger=False,
    accelerator='auto',
    devices='auto',
    )

# Data module for the training configuration (train_res / train_energy,
# stage='train'). It is only handed to load_model and then discarded.
# NOTE(review): presumably load_model reads data dimensions/statistics from it —
# confirm it expects the training config rather than the test one.
dm_trained = DataModule(datapath=datapath,
                        cached=cached,
                        max_samples=max_samples,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        res=train_res,
                        energy=train_energy,
                        stage='train')
dm_trained.prepare_data()

# NOTE: this rebinds `model` from the architecture name set in the config cell
# to the loaded network; re-run the config cell before re-running this cell.
model, model_name = load_model(model, h_dim, n_layers, k_size, dm_trained, saved,
                               results_dir, criterion, lr, amsgrad, pc_err)
del dm_trained

dm_test = DataModule(datapath=datapath,
                     cached=cached,
                     max_samples=max_samples,
                     batch_size=batch_size,
                     num_workers=num_workers,
                     res=test_res,
                     energy=test_energy,
                     stage='test')
trainer.test(model=model, datamodule=dm_test)

if predict:
    # Output location and file-format details are attached to the model
    # object; presumably its predict step reads them when writing files.
    outfolder = os.path.join(results_dir, 'Predictions', model_name, test_res, test_energy)
    os.makedirs(outfolder, exist_ok=True)
    model.outfolder = outfolder

    # Take the header line and the x grid from a reference test file.
    # Reading the data from the same handle skips the header line we just
    # consumed, whether or not it starts with '#'.
    sample_file = os.path.join(datapath, test_res, test_energy, '0.dat')
    with open(sample_file, 'r') as f:
        model.header = f.readline()
        data_sample = np.loadtxt(f, ndmin=2)  # ndmin=2 keeps [:, 0] valid for a single row
    # np.unique already returns the values sorted.
    model.x_values = np.unique(data_sample[:, 0])

    predictions = trainer.predict(model, datamodule=dm_test)

    # TODO: when `visualize` is set, plot samples from `predictions` with
    # visualize_target_output — each predict batch is presumably
    # `(pred, target, fns)`; confirm the predict_step return format first.



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


	 missing 5020 366.dat
	 missing 5020 126.dat
	 missing 5020 393.dat
	 missing 5020 278.dat
	 missing 5020 61.dat
	 missing 5020 220.dat
	 missing 5020 497.dat
	 missing 5020 422.dat
	 missing 5020 266.dat
	 missing 5020 42.dat
	 missing 5020 76.dat
	 missing 5020 285.dat
	 missing 5020 37.dat
	 missing 5020 304.dat
	 missing 5020 450.dat
	 missing 5020 353.dat
	 missing 5020 379.dat
	 missing 5020 44.dat
	 missing 5020 446.dat
	 missing 5020 321.dat
	 missing 5020 355.dat
	 missing 5020 401.dat
	 missing 5020 28.dat
	 missing 5020 166.dat
	 missing 5020 93.dat
	 missing 5020 177.dat
	 missing 5020 435.dat
	 missing 5020 214.dat
	 missing 5020 12.dat
	 missing 5020 77.dat
	 missing 5020 39.dat
	 missing 5020 9.dat
	 missing 5020 170.dat
	 missing 5020 410.dat
	 missing 5020 271.dat
	 missing 5020 344.dat
	 missing 5020 53.dat
	 missing 5020 338.dat
	 missing 5020 69.dat
	 missing 5020 449.dat
	 missing 5020 145.dat
	 missing 5020 361.dat
	 missing 5020 305.dat
	 missing 5020 113.dat
	 

KeyboardInterrupt: 