In [None]:
from torch import nn
import torch
from torch.autograd import Variable
import numpy as np
from methylnet.schedulers import *
from methylnet.plotter import *
from sklearn.preprocessing import LabelEncoder
from pymethylprocess.visualizations import umap_embed, plotly_plot
import copy
import inspect

In [None]:
RANDOM_SEED = 42


def _seed_everything(seed):
	"""Seed the global NumPy and PyTorch RNGs and pin cuDNN to deterministic kernels.

	Executed once when this cell runs, so weight initialisation and anything
	else drawing from these global generators is repeatable between runs.
	"""
	np.random.seed(seed)
	torch.manual_seed(seed)
	# Deterministic cuDNN kernels and no autotuning: slower, but reproducible.
	torch.backends.cudnn.deterministic = True
	torch.backends.cudnn.benchmark = False


_seed_everything(RANDOM_SEED)

def train_vae(model, loader, loss_func, optimizer, cuda=True, epoch=0, kl_warm_up=0, beta=1.):
	"""Run one training epoch of a VAE and return the summed losses.

	Parameters
	----------
	model : nn.Module
		VAE whose forward pass returns ``(reconstruction, mean, logvar)``.
	loader : iterable
		Yields ``(inputs, _)`` batches and exposes ``loader.dataset.length``
		and ``loader.batch_size`` (used to decide where to stop, see below).
	loss_func : callable
		Reconstruction loss, forwarded to ``vae_loss``.
	optimizer : torch.optim.Optimizer
		Stepped once per processed batch.
	cuda : bool
		Move each batch to the GPU before the forward pass.
	epoch, kl_warm_up, beta :
		Forwarded to ``vae_loss`` to scale the KL term.

	Returns
	-------
	tuple
		``(model, total_loss, total_recon_loss, total_kl_loss)``; the losses
		are Python floats summed over the processed batches.
	"""
	model.train(True)
	# Only full batches are processed: when the dataset length is not a
	# multiple of the batch size, the trailing partial batch is skipped.
	# NOTE(review): a dataset smaller than one batch gives stop_iter == 0, so
	# no batch is processed at all.
	stop_iter = loader.dataset.length // loader.batch_size
	total_loss, total_recon_loss, total_kl_loss = 0., 0., 0.
	for i, (inputs, _) in enumerate(loader):
		if i == stop_iter:
			break
		# Reshape to 2-D using the first two dimensions of the batch.
		inputs = inputs.view(inputs.size(0), inputs.size(1))
		if cuda:
			inputs = inputs.cuda()
		output, mean, logvar = model(inputs)
		loss, reconstruction_loss, kl_loss = vae_loss(output, inputs, mean, logvar, loss_func, epoch, kl_warm_up, beta)
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		total_loss += loss.item()
		total_recon_loss += reconstruction_loss.item()
		total_kl_loss += kl_loss.item()
	return model, total_loss, total_recon_loss, total_kl_loss

In [None]:
def val_vae(model, loader, loss_func, optimizer, cuda=True, epoch=0, kl_warm_up=0, beta=1.):
	"""Evaluate a VAE over ``loader`` without updating any weights.

	Mirrors ``train_vae`` (same batch handling and early stop) but runs in
	eval mode under ``torch.no_grad()``. ``optimizer`` is accepted only for
	signature symmetry and is never used.

	Returns ``(model, total_loss, total_recon_loss, total_kl_loss)`` with the
	losses summed, as Python floats, over every full batch.
	"""
	model.eval()
	n_full_batches = loader.dataset.length // loader.batch_size
	totals = [0., 0., 0.]
	with torch.no_grad():
		for batch_idx, (batch, _) in enumerate(loader):
			# Stop at the trailing partial batch (if any), exactly as in training.
			if batch_idx == n_full_batches:
				break
			batch = batch.view(batch.size()[0], batch.size()[1])
			if cuda:
				batch = batch.cuda()
			reconstruction, mean, logvar = model(batch)
			losses = vae_loss(reconstruction, batch, mean, logvar, loss_func, epoch, kl_warm_up, beta)
			for slot, value in enumerate(losses):
				totals[slot] += value.item()
	total_loss, total_recon_loss, total_kl_loss = totals
	return model, total_loss, total_recon_loss, total_kl_loss

In [None]:
def project_vae(model, loader, cuda=True):
	"""Encode every batch of ``loader`` into the model's latent space.

	Parameters
	----------
	model : nn.Module
		Any module exposing ``get_latent_z(inputs)`` returning a batch-first tensor.
	loader : iterable
		Yields ``(inputs, outcomes)`` pairs; ``outcomes`` is ignored.
	cuda : bool
		Move each batch to the GPU before encoding.

	Returns
	-------
	tuple
		``(z, None, None)`` where ``z`` is a NumPy array with one row per
		input sample. The two ``None`` slots are kept for callers that
		unpack three values.

	Raises
	------
	ValueError
		If ``loader`` yields no batches.
	"""
	print(model)
	model.eval()
	latent_batches = []
	with torch.no_grad():
		for inputs, _ in loader:
			inputs = inputs.view(inputs.size(0), inputs.size(1))
			if cuda:
				inputs = inputs.cuda()
			z_batch = model.get_latent_z(inputs).detach().cpu().numpy()
			# Keep the batch axis explicit: np.squeeze would collapse a 1-wide
			# latent space and make the vstack below mis-shape or fail.
			latent_batches.append(z_batch.reshape(z_batch.shape[0], -1))
	if not latent_batches:
		raise ValueError('project_vae: loader yielded no batches')
	return np.vstack(latent_batches), None, None

In [None]:

class AutoEncoder:
	"""Training / inference wrapper around a VAE ``nn.Module``.

	Parameters
	----------
	autoencoder_model : nn.Module
		VAE whose forward returns ``(reconstruction, mean, logvar)`` and which
		exposes ``get_latent_z`` (used by ``transform``).
	n_epochs : int
		Number of training epochs; must be positive.
	loss_fn : callable
		Reconstruction loss forwarded to ``train_vae`` / ``val_vae``.
	optimizer : torch.optim.Optimizer
		Optimizer over the model's parameters.
	cuda : bool
		Move the model (and every batch) to the GPU.
	kl_warm_up : int
		Epochs before this index are forwarded to the loss as warm-up epochs
		and are not considered when selecting the best model.
	beta : float
		Weight of the KL term, forwarded to the loss.
	scheduler_opts : dict or None
		Options for ``Scheduler``; when empty/None the scheduler defaults are used.
	"""
	def __init__(self, autoencoder_model, n_epochs, loss_fn, optimizer, cuda=True, kl_warm_up=0, beta=1., scheduler_opts=None):
		self.model = autoencoder_model
		if cuda:
			self.model = self.model.cuda()
		self.n_epochs = n_epochs
		self.loss_fn = loss_fn
		self.optimizer = optimizer
		self.cuda = cuda
		self.kl_warm_up = kl_warm_up
		self.beta = beta
		self.scheduler = Scheduler(self.optimizer, scheduler_opts) if scheduler_opts else Scheduler(self.optimizer)
		# Not used by this class; kept so callers that read or set them keep working.
		self.vae_animation_fname = 'animation.mp4'
		self.loss_plt_fname = 'loss.png'
		self.plot_interval = 5
		self.embed_interval = 200
		self.validation_set = False

	def fit(self, train_data):
		"""Train for ``n_epochs`` epochs and keep the best model.

		From epoch ``kl_warm_up`` onwards, the epoch with the lowest total loss
		(validation loss when a validation set was added, training loss
		otherwise) is selected. If every epoch falls inside the warm-up, the
		model from the final epoch is kept instead.

		Sets ``model``, ``best_epoch``, ``min_loss``, ``min_val_loss``,
		``min_val_kl_loss``, ``min_val_recon_loss`` and ``training_plot_data``
		and returns ``self``.

		Raises
		------
		ValueError
			If ``n_epochs`` is not a positive integer.
		"""
		if not self.n_epochs or self.n_epochs < 1:
			raise ValueError('n_epochs must be a positive integer, got {!r}'.format(self.n_epochs))
		model = self.model
		best_model = None
		best_epoch = None
		loss_list = []
		plt_data = {'kl_loss': [], 'recon_loss': [], 'lr': [], 'val_kl_loss': [], 'val_recon_loss': [], 'val_loss': []}
		for epoch in range(self.n_epochs):
			model, loss, recon_loss, kl_loss = train_vae(model, train_data, self.loss_fn, self.optimizer, self.cuda, epoch, self.kl_warm_up, self.beta)
			self.scheduler.step()
			plt_data['kl_loss'].append(kl_loss)
			plt_data['recon_loss'].append(recon_loss)
			plt_data['lr'].append(self.scheduler.get_lr())
			print("Epoch {}: Loss {}, Recon Loss {}, KL-Loss {}".format(epoch, loss, recon_loss, kl_loss))
			if self.validation_set:
				model, val_loss, val_recon_loss, val_kl_loss = val_vae(model, self.validation_set, self.loss_fn, self.optimizer, self.cuda, epoch, self.kl_warm_up, self.beta)
				plt_data['val_kl_loss'].append(val_kl_loss)
				plt_data['val_recon_loss'].append(val_recon_loss)
				plt_data['val_loss'].append(val_loss)
				print("Epoch {}: Val-Loss {}, Val-Recon Loss {}, Val-KL-Loss {}".format(epoch, val_loss, val_recon_loss, val_kl_loss))
			if epoch >= self.kl_warm_up:
				selection_loss = val_loss if self.validation_set else loss
				loss_list.append(selection_loss)
				if selection_loss <= min(loss_list):
					best_model = copy.deepcopy(model)
					best_epoch = epoch
		if best_epoch is None:
			# Every epoch was inside the KL warm-up (or all losses were NaN):
			# previously best_epoch was unbound here; fall back to the final model.
			best_model = copy.deepcopy(model)
			best_epoch = self.n_epochs - 1
		self.min_loss = min(np.array(plt_data['kl_loss']) + np.array(plt_data['recon_loss']))
		if self.validation_set:
			self.min_val_loss = plt_data['val_loss'][best_epoch]
			self.min_val_kl_loss = plt_data['val_kl_loss'][best_epoch]
			self.min_val_recon_loss = plt_data['val_recon_loss'][best_epoch]
		else:
			self.min_val_loss, self.min_val_kl_loss, self.min_val_recon_loss = -1., -1., -1.
		self.best_epoch = best_epoch
		self.model = best_model
		self.training_plot_data = plt_data
		return self

	def add_validation_set(self, validation_data):
		"""Use ``validation_data`` for per-epoch validation and model selection."""
		self.validation_set = validation_data

	def transform(self, train_data):
		"""Project ``train_data`` into latent space; returns ``project_vae``'s tuple."""
		return project_vae(self.model, train_data, self.cuda)

	def fit_transform(self, train_data):
		"""``fit`` on ``train_data`` and then ``transform`` it."""
		return self.fit(train_data).transform(train_data)

	def generate(self, train_data):
		"""Return the model's reconstructions of every batch as one NumPy array.

		``train_data`` yields ``(X, _)`` pairs; the first forward output of
		each batch is concatenated along axis 0. Leaves the model in eval mode.
		"""
		self.model.eval()
		reconstructions = []
		with torch.no_grad():
			for X, _ in train_data:
				if self.cuda:
					X = X.cuda()
				reconstructions.append(self.model(X)[0].detach().cpu())
			X_hat = torch.cat(reconstructions, 0).numpy()
		return X_hat

def vae_loss(output, input, mean, logvar, loss_func, epoch, kl_warm_up=0, beta=1.):
	"""Return ``(total, reconstruction, kl)`` losses for a VAE batch.

	``output`` may be a single reconstruction or a list of them; the
	reconstruction term is the sum of ``loss_func(out, input)`` over all of
	them. The KL term is the batch mean of
	``0.5 * sum(exp(logvar) + mean**2 - 1 - logvar)`` over dimension 1,
	multiplied by ``beta`` and, while ``epoch < kl_warm_up``, linearly
	annealed by ``epoch / kl_warm_up``.
	"""
	outputs = output if type(output) is list else [output]
	recon_loss = sum(loss_func(candidate, input) for candidate in outputs)
	per_sample_kl = 0.5 * torch.sum(logvar.exp() + mean.pow(2) - 1. - logvar, 1)
	kl_loss = per_sample_kl.mean() * beta
	if epoch < kl_warm_up:
		kl_loss = kl_loss * np.clip(epoch / kl_warm_up, 0., 1.)
	return recon_loss + kl_loss, recon_loss, kl_loss

In [None]:
class TybaltTitusVAE(nn.Module):
	"""Fully connected VAE with a mirrored encoder/decoder.

	Parameters
	----------
	n_input : int
		Width of each input row (and of the reconstruction).
	n_latent : int
		Size of the latent space.
	hidden_layer_encoder_topology : sequence of int or None
		Widths of the encoder's hidden ``Linear -> ReLU`` layers; the decoder
		uses the reversed sequence. ``None`` or empty means no hidden layers.
	cuda : bool
		Stored as ``cuda_on`` for backward compatibility. Sampling noise is
		created on the same device as ``logvar`` regardless of this flag.
	"""
	def __init__(self, n_input, n_latent, hidden_layer_encoder_topology=(100, 100, 100), cuda=False):
		super(TybaltTitusVAE, self).__init__()
		self.n_input = n_input
		self.n_latent = n_latent
		self.cuda_on = cuda
		hidden = list(hidden_layer_encoder_topology) if hidden_layer_encoder_topology else []
		self.pre_latent_topology = [n_input] + hidden
		self.post_latent_topology = [n_latent] + hidden[::-1]
		# Layer creation order is kept identical to the original so that seeded
		# initialisation and state_dict keys are unchanged.
		self.encoder_layers = self._dense_relu_stack(self.pre_latent_topology)
		self.encoder = nn.Sequential(*self.encoder_layers) if self.encoder_layers else nn.Dropout(p=0.)
		self.z_mean = nn.Sequential(nn.Linear(self.pre_latent_topology[-1], n_latent), nn.BatchNorm1d(n_latent))
		self.z_var = nn.Sequential(nn.Linear(self.pre_latent_topology[-1], n_latent), nn.BatchNorm1d(n_latent))
		# NOTE(review): z_develop is not used by any method in this class; kept
		# so existing checkpoints still load.
		self.z_develop = nn.Linear(n_latent, self.pre_latent_topology[-1])
		self.decoder_layers = nn.Sequential(*self._dense_relu_stack(self.post_latent_topology))
		self.output_layer = nn.Sequential(nn.Linear(self.post_latent_topology[-1], n_input), nn.Sigmoid())
		if len(self.decoder_layers) > 0:
			self.decoder = nn.Sequential(self.decoder_layers, self.output_layer)
		else:
			self.decoder = self.output_layer

	@staticmethod
	def _dense_relu_stack(topology):
		"""Build Xavier-initialised ``Linear -> ReLU`` blocks between consecutive widths."""
		blocks = []
		for n_in, n_out in zip(topology[:-1], topology[1:]):
			layer = nn.Linear(n_in, n_out)
			torch.nn.init.xavier_uniform_(layer.weight)
			blocks.append(nn.Sequential(layer, nn.ReLU()))
		return blocks

	def sample_z(self, mean, logvar):
		"""Reparameterisation: ``mean + eps * exp(0.5 * logvar)`` in training mode.

		In eval mode the mean is returned unchanged. ``eps`` is drawn with
		``torch.randn_like`` so it lives on the same device/dtype as ``logvar``
		(the old code built it on the CPU and only moved it when ``cuda_on``
		was set, which broke GPU models without the flag and CPU models with it).
		"""
		if not self.training:
			return mean
		stddev = torch.exp(0.5 * logvar)
		noise = torch.randn_like(stddev)
		return noise * stddev + mean

	def encode(self, x):
		"""Return ``(mean, logvar)`` of the approximate posterior for batch ``x``."""
		hidden = self.encoder(x)
		return self.z_mean(hidden), self.z_var(hidden)

	def decode(self, z):
		"""Map latent codes ``z`` to a reconstruction in (0, 1) (sigmoid output)."""
		return self.decoder(z)

	def forward(self, x):
		"""Return ``(reconstruction, mean, logvar)`` for batch ``x``."""
		mean, logvar = self.encode(x)
		z = self.sample_z(mean, logvar)
		return self.decode(z), mean, logvar

	def get_latent_z(self, x):
		"""Encode ``x`` and return a (sampled in training mode) latent code."""
		mean, logvar = self.encode(x)
		return self.sample_z(mean, logvar)

	def forward_predict(self, x):
		"""Alias for ``get_latent_z``."""
		return self.get_latent_z(x)

def train_mlp(model, loader, loss_func, optimizer_vae, optimizer_mlp, cuda=True, categorical=False, train_decoder=False):
	"""Run one fine-tuning epoch for a VAE encoder + MLP head.

	Parameters
	----------
	model : nn.Module
		Returns ``(y_predict, z)`` for a batch of inputs.
	loader : iterable
		Yields ``(inputs, y_true)`` and exposes ``dataset.length`` and ``batch_size``.
	loss_func : callable
		``loss_func(y_predict, y_true)``.
	optimizer_vae, optimizer_mlp : torch.optim.Optimizer
		Both are zeroed and stepped on every processed batch.
	cuda : bool
		Move inputs and targets to the GPU.
	categorical : bool
		Convert one-hot targets to class indices with ``argmax(1)``.
	train_decoder : bool
		Also call ``train_decoder_`` on every batch and print its summed loss.

	Returns
	-------
	tuple
		``(model, summed_loss)``.
	"""
	model.train(True)
	last_full_batch = loader.dataset.length // loader.batch_size
	epoch_loss = 0.
	epoch_decoder_loss = 0.
	optimizers = (optimizer_vae, optimizer_mlp)
	for batch_idx, (batch, target) in enumerate(loader):
		# Drop a trailing one-sample batch (presumably to avoid BatchNorm
		# failing on a single sample in train mode).
		if batch.size()[0] == 1 and batch_idx == last_full_batch:
			break
		batch = batch.view(batch.size()[0], batch.size()[1])
		if categorical:
			target = target.argmax(1).long()
		if cuda:
			batch, target = batch.cuda(), target.cuda()
		prediction, latent = model(batch)
		loss = loss_func(prediction, target)
		for opt in optimizers:
			opt.zero_grad()
		loss.backward()
		if train_decoder:
			model, decoder_loss = train_decoder_(model, batch, latent)
			epoch_decoder_loss += decoder_loss
		for opt in optimizers:
			opt.step()
		epoch_loss += loss.item()
	if train_decoder:
		print('Decoder Loss is {}'.format(epoch_decoder_loss))
	return model, epoch_loss

def val_mlp(model, loader, loss_func, cuda=True, categorical=False, train_decoder=False):
	"""Evaluate a VAE encoder + MLP head without updating weights.

	Uses the same batch handling as ``train_mlp`` but runs in eval mode under
	``torch.no_grad()``. When ``train_decoder`` is set, the decoder's
	reconstruction loss from ``val_decoder_`` is summed and printed too.

	Returns ``(model, summed_loss)``.
	"""
	model.eval()
	last_full_batch = loader.dataset.length // loader.batch_size
	epoch_decoder_loss = 0.
	epoch_loss = 0.
	with torch.no_grad():
		for batch_idx, (batch, target) in enumerate(loader):
			# Drop a trailing one-sample batch, mirroring train_mlp.
			if batch.size()[0] == 1 and batch_idx == last_full_batch:
				break
			batch = batch.view(batch.size()[0], batch.size()[1])
			if categorical:
				target = target.argmax(1).long()
			if cuda:
				batch, target = batch.cuda(), target.cuda()
			prediction, latent = model(batch)
			epoch_loss += loss_func(prediction, target).item()
			if train_decoder:
				epoch_decoder_loss += val_decoder_(model, batch, latent)
		if train_decoder:
			print('Val Decoder Loss is {}'.format(epoch_decoder_loss))
	return model, epoch_loss

def test_mlp(model, loader, categorical, cuda=True, output_latent=True):

	model.eval()
	Y_pred=[]
	final_latent=[]
	Y_true=[]
	with torch.no_grad():
		for inputs, y_true in loader: # change dataloder for classification/regression tasks
			print(inputs)
			inputs = Variable(inputs).view(inputs.size()[0],inputs.size()[1])
			y_true = Variable(y_true)
			#print(inputs.size())
			if cuda:
				inputs = inputs.cuda()
				y_true = y_true.cuda()
			y_predict, z = model(inputs)
			y_predict=np.squeeze(y_predict.detach().cpu().numpy())
			y_true=np.squeeze(y_true.detach().cpu().numpy())
			#print(y_predict.shape,y_true.shape)
			#print(y_predict,y_true)
			if len(y_predict.shape) < 2:
				y_predict=y_predict.flatten()
			if len(y_true.shape) < 2:
				y_true=y_true.flatten()  # FIXME
			Y_pred.append(y_predict)
			final_latent.append(np.squeeze(z.detach().cpu().numpy()))
			Y_true.append(y_true)
	if len(Y_pred) > 1:
		if all(list(map(lambda x: len(np.shape(x))<2,Y_pred))):
			Y_pred = np.hstack(Y_pred)[:,np.newaxis]
		else:
			Y_pred=np.vstack(Y_pred)
	else:
		Y_pred = Y_pred[0]
		if len(np.shape(Y_pred))<2:
			Y_pred=Y_pred[:,np.newaxis]
	if len(final_latent) > 1:
		final_latent=np.vstack(final_latent)
	else:
		final_latent = final_latent[0]
	if len(Y_true) > 1:
		if all(list(map(lambda x: len(np.shape(x))<2,Y_true))):
			Y_true = np.hstack(Y_true)[:,np.newaxis]
		else:
			Y_true=np.vstack(Y_true)
	else:
		Y_true = Y_true[0]
		if len(np.shape(Y_true))<2:
			Y_true=Y_true[:,np.newaxis]
	print(Y_pred,Y_true)
	#print(np.hstack([Y_pred,Y_true]))
	if output_latent:
		return Y_pred, Y_true, final_latent, None
	else:
		return Y_pred

def train_decoder_(model, x, z):
	"""Accumulate decoder-only gradients from a summed-BCE reconstruction of ``x``.

	Parameters
	----------
	model : nn.Module
		Wrapper exposing ``vae`` (with a ``decoder`` submodule) and
		``decode(z)``; ``decode`` may return a tensor or a list of tensors.
	x : Tensor
		Reconstruction target; ``nn.BCELoss`` expects values in [0, 1].
	z : Tensor
		Latent codes. They are detached so the reconstruction loss reaches
		only the decoder and never re-enters the encoder graph, which the
		caller may already have back-propagated through (``train_mlp`` calls
		this right after ``loss.backward()``).

	Returns
	-------
	tuple
		``(model, loss)`` where ``loss`` is the summed BCE as a float.

	Only ``.grad`` of the decoder parameters is updated; no optimizer step
	is taken. Every parameter's ``requires_grad`` flag and the VAE's
	train/eval mode are restored to their values on entry, even on error.
	"""
	params = list(model.parameters())
	saved_flags = [p.requires_grad for p in params]
	was_training = model.vae.training
	loss_fn = nn.BCELoss(reduction='sum')
	try:
		model.vae.train(True)
		for p in params:
			p.requires_grad = False
		for p in model.vae.decoder.parameters():
			p.requires_grad = True
		x_hat = model.decode(z.detach())
		if not isinstance(x_hat, list):
			x_hat = [x_hat]
		loss = sum(loss_fn(x_h, x) for x_h in x_hat)
		loss.backward()
	finally:
		for p, flag in zip(params, saved_flags):
			p.requires_grad = flag
		model.vae.train(was_training)
	return model, loss.item()

def val_decoder_(model, x, z):
	"""Return the summed-BCE reconstruction loss of ``x`` from ``z`` as a float.

	``model`` must expose ``vae`` and ``decode(z)``; ``decode`` may return a
	tensor or a list of tensors, whose losses are summed. The computation
	runs with the VAE in eval mode and without autograd; the VAE's previous
	train/eval mode is restored afterwards. (The old version referenced an
	undefined ``loss_func`` and raised NameError.)
	"""
	was_training = model.vae.training
	loss_fn = nn.BCELoss(reduction='sum')
	model.vae.eval()
	try:
		with torch.no_grad():
			x_hat = model.decode(z)
			if not isinstance(x_hat, list):
				x_hat = [x_hat]
			loss = sum(loss_fn(x_h, x) for x_h in x_hat)
	finally:
		model.vae.train(was_training)
	return loss.item()

In [None]:
class MLPFinetuneVAE:
	"""Trainer that fine-tunes a VAE-encoder + MLP-head model.

	Parameters
	----------
	mlp_model : nn.Module
		Model with a ``vae`` attribute (e.g. ``VAE_MLP``).
	n_epochs : int
		Number of training epochs; must be positive when ``fit`` is called.
	loss_fn : callable
		Prediction loss forwarded to ``train_fn`` / ``val_fn``.
	optimizer_vae, optimizer_mlp : torch.optim.Optimizer or None
		Both are required for ``fit``; schedulers are built only when both are given.
	cuda : bool
		Move the model (and batches) to the GPU.
	categorical : bool
		Forwarded to the train/val helpers (one-hot targets -> class indices).
	scheduler_opts : dict or None
		Options for ``Scheduler``; empty/None uses the scheduler defaults.
	output_latent : bool
		Forwarded to ``test_fn`` by ``predict``.
	train_decoder : bool
		Forwarded to the train/val helpers.
	"""

	def __init__(self, mlp_model, n_epochs=None, loss_fn=None, optimizer_vae=None, optimizer_mlp=None, cuda=True, categorical=False, scheduler_opts=None, output_latent=True, train_decoder=False):
		self.model = mlp_model
		# Keep the wrapped VAE's cuda_on flag in sync with this trainer.
		self.model.vae.cuda_on = cuda
		if cuda:
			self.model = self.model.cuda()
		self.n_epochs = n_epochs
		self.loss_fn = loss_fn
		self.optimizer_vae = optimizer_vae
		self.optimizer_mlp = optimizer_mlp
		self.cuda = cuda
		if self.optimizer_vae is not None and self.optimizer_mlp is not None:
			self.scheduler_vae = Scheduler(self.optimizer_vae, scheduler_opts) if scheduler_opts else Scheduler(self.optimizer_vae)
			self.scheduler_mlp = Scheduler(self.optimizer_mlp, scheduler_opts) if scheduler_opts else Scheduler(self.optimizer_mlp)
		else:
			self.scheduler_vae = None
			self.scheduler_mlp = None
		# Not used by this class; kept so callers that read or set them keep working.
		self.loss_plt_fname = 'loss.png'
		self.embed_interval = 200
		self.return_latent = True
		self.validation_set = False
		self.categorical = categorical
		self.output_latent = output_latent
		# FIXME add loss for decoder if selecting this option and freeze other weights when updating decoder, also change forward function to include reconstruction, change back when done
		self.train_decoder = train_decoder
		self.train_fn = train_mlp
		self.val_fn = val_mlp
		self.test_fn = test_mlp

	def fit(self, train_data):
		"""Train for ``n_epochs`` epochs and keep the best model.

		The epoch with the lowest loss (validation loss when a validation set
		was added, training loss otherwise) is selected; if no epoch qualifies
		(all losses NaN) the final model is kept. Sets ``model``,
		``best_epoch``, ``min_loss``, ``min_val_loss`` and
		``training_plot_data`` and returns ``self``.

		Raises
		------
		ValueError
			If ``n_epochs`` is not positive or the optimizers/schedulers are missing.
		"""
		if not self.n_epochs or self.n_epochs < 1:
			raise ValueError('n_epochs must be a positive integer, got {!r}'.format(self.n_epochs))
		if self.scheduler_vae is None or self.scheduler_mlp is None:
			raise ValueError('fit() requires both optimizer_vae and optimizer_mlp')
		model = self.model
		print(model)
		best_model = None
		best_epoch = None
		loss_list = []
		plt_data = {'loss': [], 'lr_vae': [], 'lr_mlp': [], 'val_loss': []}
		for epoch in range(self.n_epochs):
			model, loss = self.train_fn(model, train_data, self.loss_fn, self.optimizer_vae, self.optimizer_mlp, self.cuda, categorical=self.categorical, train_decoder=self.train_decoder)
			self.scheduler_vae.step()
			self.scheduler_mlp.step()
			plt_data['loss'].append(loss)
			print("Epoch {}: Loss {}".format(epoch, loss))
			if self.validation_set:
				model, val_loss = self.val_fn(model, self.validation_set, self.loss_fn, self.cuda, categorical=self.categorical, train_decoder=self.train_decoder)
				plt_data['val_loss'].append(val_loss)
				print("Epoch {}: Val-Loss {}".format(epoch, val_loss))
			plt_data['lr_vae'].append(self.scheduler_vae.get_lr())
			plt_data['lr_mlp'].append(self.scheduler_mlp.get_lr())
			selection_loss = val_loss if self.validation_set else loss
			loss_list.append(selection_loss)
			if selection_loss <= min(loss_list):
				best_model = copy.deepcopy(model)
				best_epoch = epoch
		if best_epoch is None:
			# No epoch was ever selected (e.g. NaN losses); keep the final model.
			best_model = copy.deepcopy(model)
			best_epoch = self.n_epochs - 1
		self.training_plot_data = plt_data
		self.min_loss = min(plt_data['loss'])
		self.min_val_loss = min(plt_data['val_loss']) if self.validation_set else -1
		self.best_epoch = best_epoch
		self.model = best_model
		return self

	def add_validation_set(self, validation_data):
		"""Use ``validation_data`` for per-epoch validation and model selection."""
		self.validation_set = validation_data

	def predict(self, test_data):
		"""Run ``test_fn`` (``test_mlp`` by default) on ``test_data``."""
		return self.test_fn(self.model, test_data, self.categorical, self.cuda, self.output_latent)

In [None]:
class VAE_MLP(nn.Module):
	"""VAE encoder followed by an MLP prediction head.

	Parameters
	----------
	vae_model : nn.Module
		Must expose ``n_latent``, ``get_latent_z(x)`` and a ``decoder`` submodule.
	n_output : int
		Width of the MLP output.
	categorical : bool
		Stored on the instance; not used by this class's forward methods.
	hidden_layer_topology : sequence of int or None
		Widths of the hidden ``Linear -> ReLU -> Dropout`` layers.
	dropout_p : float
		Dropout probability of each hidden layer.
	add_softmax : bool
		Append a softmax over the last dimension to the MLP output.
	"""

	# TODO: add ability to train the decoder.
	def __init__(self, vae_model, n_output, categorical=False, hidden_layer_topology=(100, 100, 100), dropout_p=0.2, add_softmax=False):
		super(VAE_MLP, self).__init__()
		self.vae = vae_model
		self.n_output = n_output
		self.categorical = categorical
		self.add_softmax = add_softmax
		self.topology = [self.vae.n_latent] + (list(hidden_layer_topology) if hidden_layer_topology else [])
		self.mlp_layers = []
		self.dropout_p = dropout_p
		for n_in, n_out in zip(self.topology[:-1], self.topology[1:]):
			layer = nn.Linear(n_in, n_out)
			torch.nn.init.xavier_uniform_(layer.weight)
			self.mlp_layers.append(nn.Sequential(layer, nn.ReLU(), nn.Dropout(self.dropout_p)))
		self.output_layer = nn.Linear(self.topology[-1], self.n_output)
		torch.nn.init.xavier_uniform_(self.output_layer.weight)
		# dim=-1 is the axis the deprecated implicit choice picked for 1-D and
		# 2-D inputs, without the runtime warning.
		self.mlp_layers.extend([self.output_layer] + ([nn.Softmax(dim=-1)] if self.add_softmax else []))
		self.mlp = nn.Sequential(*self.mlp_layers)
		self.output_z = False

	def forward(self, x):
		"""Return ``(mlp_prediction, z)`` where ``z`` is the VAE latent code."""
		z = self.vae.get_latent_z(x)
		return self.mlp(z), z

	def decode(self, z):
		"""Decode latent codes with the VAE's decoder."""
		return self.vae.decoder(z)

	def forward_embed(self, x):
		"""Return ``(mlp_prediction, z, reconstruction)``."""
		out = self.vae.get_latent_z(x)
		recon = self.vae.decoder(out)
		return self.mlp(out), out, recon

	def toggle_latent_z(self):
		"""Flip whether ``forward_predict`` returns latents instead of predictions."""
		self.output_z = not self.output_z

	def forward_predict(self, x):
		"""Return the latent code if ``output_z`` is set, else the MLP prediction."""
		if self.output_z:
			return self.vae.get_latent_z(x)
		return self.mlp(self.vae.get_latent_z(x))

In [None]:
class MLP(nn.Module):
	"""Feed-forward network with ``Linear -> LeakyReLU -> Dropout`` hidden layers.

	Parameters
	----------
	n_input : int
		Input width.
	hidden_topology : sequence of int
		Widths of the hidden layers (may be empty).
	dropout_p : float
		Dropout probability of each hidden layer.
	n_outputs : int
		Output width.
	binary, softmax, relu_out : bool
		Output activation, checked in this order: sigmoid, softmax over the
		last dimension, ReLU; identity when all are False.
	"""
	def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=True, softmax=False, relu_out=False):
		super(MLP, self).__init__()
		self.topology = [n_input] + list(hidden_topology) + [n_outputs]
		# All hidden Linear layers are created before any is initialised, as in
		# the original, so seeded initialisation is unchanged.
		layers = [nn.Linear(self.topology[i], self.topology[i + 1]) for i in range(len(self.topology) - 2)]
		for layer in layers:
			torch.nn.init.xavier_uniform_(layer.weight)
		self.layers = [nn.Sequential(layer, nn.LeakyReLU(), nn.Dropout(p=dropout_p)) for layer in layers]
		output_layer = nn.Linear(self.topology[-2], self.topology[-1])
		torch.nn.init.xavier_uniform_(output_layer.weight)
		if binary:
			output_transform = nn.Sigmoid()
		elif softmax:
			# Explicit dim avoids the deprecated implicit choice (same axis for 2-D input).
			output_transform = nn.Softmax(dim=-1)
		elif relu_out:
			output_transform = nn.ReLU()
		else:
			output_transform = nn.Dropout(p=0.)
		self.layers.append(nn.Sequential(output_layer, output_transform))
		self.mlp = nn.Sequential(*self.layers)

	def forward(self, x):
		"""Apply the network to batch ``x``."""
		return self.mlp(x)

In [None]:
class MLPTrainer:
	"""Epoch-based trainer for a plain ``nn.Module`` such as ``MLP``.

	Parameters
	----------
	mlp_model : nn.Module
		Model mapping a batch ``X`` to predictions.
	n_epoch : int
		Number of training epochs (must be positive when ``fit`` is called).
	validation_dataloader : iterable of (X, y) or None
		Used to select the best epoch. When None, the training loss is used
		for model selection instead.
	optimizer_opts : dict or None
		``name`` (``'adam'`` or ``'lbfgs'``; default ``'adam'``) plus keyword
		arguments for that optimizer. Keys its constructor does not accept are
		dropped. The dict is copied, never mutated. None means
		``dict(name='adam', lr=1e-3, weight_decay=1e-4)``.
	scheduler_opts : dict or None
		Copied and passed to ``Scheduler(optimizer=..., opts=...)``. None means
		``dict(scheduler='warm_restarts', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2)``.
	class_weights : array-like or None
		Stored on the instance; not used by this class. None means an empty array.
	loss_fn : callable
		``loss_fn(y_pred, y_true)`` returning a scalar tensor.
	categorical : bool
		Convert one-hot targets with ``argmax(1)``; ``predict`` returns class indices.
	"""
	_OPTIMIZERS = {'adam': torch.optim.Adam, 'lbfgs': torch.optim.LBFGS}

	def __init__(self, mlp_model, n_epoch=300, validation_dataloader=None, optimizer_opts=None, scheduler_opts=None, class_weights=None, loss_fn=None, categorical=False):
		self.mlp = mlp_model
		# Copy so that neither the caller's dict nor a shared default is mutated by pop().
		if optimizer_opts is None:
			opts = dict(name='adam', lr=1e-3, weight_decay=1e-4)
		else:
			opts = dict(optimizer_opts)
		optimizer_cls = self._OPTIMIZERS[opts.pop('name', 'adam')]
		accepted = self._accepted_kwargs(optimizer_cls)
		opts = {k: v for k, v in opts.items() if k in accepted}
		self.optimizer = optimizer_cls(self.mlp.parameters(), **opts)
		if scheduler_opts is None:
			scheduler_opts = dict(scheduler='warm_restarts', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2)
		else:
			scheduler_opts = dict(scheduler_opts)
		self.scheduler = Scheduler(optimizer=self.optimizer, opts=scheduler_opts)
		self.n_epoch = n_epoch
		self.validation_dataloader = validation_dataloader
		self.class_weights = np.array([]) if class_weights is None else class_weights
		self.loss_fn = loss_fn
		self.categorical = categorical

	@staticmethod
	def _accepted_kwargs(optimizer_cls):
		"""Keyword names accepted by ``optimizer_cls.__init__`` (excluding self/params)."""
		# inspect.getargspec was removed in Python 3.11; signature() also
		# reports keyword-only parameters.
		return set(inspect.signature(optimizer_cls.__init__).parameters) - {'self', 'params'}

	def _device(self):
		"""Device of the model's first parameter (CPU for parameter-less models)."""
		# Batches follow the model instead of torch.cuda.is_available(): the old
		# check sent data to the GPU even when the model was never moved there.
		first = next(self.mlp.parameters(), None)
		return first.device if first is not None else torch.device('cpu')

	def calc_loss(self, y_pred, y_true):
		"""Return ``loss_fn(y_pred, y_true)``."""
		return self.loss_fn(y_pred, y_true)

	def train_loop(self, train_dataloder):
		"""Run one optimisation epoch, step the scheduler, return the summed loss."""
		self.mlp.train(True)
		device = self._device()
		running_loss = 0.
		for X, y_true in train_dataloder:
			X = X.to(device)
			y_true = y_true.to(device)
			if self.categorical:
				y_true = y_true.argmax(1).long()

			def closure():
				self.optimizer.zero_grad()
				batch_loss = self.calc_loss(self.mlp(X), y_true)
				batch_loss.backward()
				return batch_loss

			# Adam and LBFGS both accept a closure (LBFGS requires one); step
			# returns the loss of the first closure evaluation.
			loss = self.optimizer.step(closure)
			running_loss += loss.item()
		self.scheduler.step()
		return running_loss

	def val_loop(self, val_dataloader):
		"""Return the summed loss over ``val_dataloader`` in eval mode, no grads."""
		self.mlp.train(False)
		device = self._device()
		running_loss = 0.
		with torch.no_grad():
			for X, y_true in val_dataloader:
				X = X.to(device)
				y_true = y_true.to(device)
				if self.categorical:
					y_true = y_true.argmax(1).long()
				running_loss += self.calc_loss(self.mlp(X), y_true).item()
		return running_loss

	def test_loop(self, test_dataloader):
		"""Return stacked predictions (class indices if ``categorical``) as NumPy."""
		self.mlp.train(False)
		device = self._device()
		y_pred = []
		with torch.no_grad():
			for batch in test_dataloader:
				y_pred.append(self.mlp(batch[0].to(device)).detach().cpu())
			y_pred = torch.cat(y_pred, 0).numpy()
			if self.categorical:
				y_pred = y_pred.argmax(1)
		return y_pred

	def fit(self, train_dataloader, verbose=True):
		"""Train for ``n_epoch`` epochs and keep the model with the lowest selection loss.

		The selection loss is the validation loss, or the training loss when
		no validation loader was given. Sets ``min_val_loss`` and
		``best_epoch``, replaces ``self.mlp`` with a deep copy of the best
		model, and returns ``self``.

		Raises
		------
		ValueError
			If ``n_epoch`` is not a positive integer.
		"""
		if not self.n_epoch or self.n_epoch < 1:
			raise ValueError('n_epoch must be a positive integer, got {!r}'.format(self.n_epoch))
		val_losses = []
		best_model, best_epoch, min_val_loss = None, None, None
		for epoch in range(self.n_epoch):
			train_loss = self.train_loop(train_dataloader)
			if self.validation_dataloader is not None:
				val_loss = self.val_loop(self.validation_dataloader)
			else:
				val_loss = train_loss
			val_losses.append(val_loss)
			if verbose:
				print("Epoch {}: Train Loss {}, Val Loss {}".format(epoch, train_loss, val_loss))
			if val_loss <= min(val_losses):
				min_val_loss = val_loss
				best_epoch = epoch
				best_model = copy.deepcopy(self.mlp)
		if best_model is None:
			# Only reachable when every loss is NaN: keep the final model.
			min_val_loss = val_losses[-1]
			best_epoch = self.n_epoch - 1
			best_model = copy.deepcopy(self.mlp)
		self.min_val_loss = min_val_loss
		self.best_epoch = best_epoch
		print('Min Val Loss {}, Best Epoch {}'.format(min_val_loss, best_epoch))
		self.mlp = best_model
		return self

	def predict(self, test_dataloader):
		"""Return predictions for ``test_dataloader`` (see ``test_loop``)."""
		return self.test_loop(test_dataloader)

	def fit_predict(self, train_dataloader, test_dataloader):
		"""``fit`` on ``train_dataloader`` then ``predict`` on ``test_dataloader``."""
		# fit() returns self; the old ``fit(...)[0]`` raised TypeError.
		return self.fit(train_dataloader).predict(test_dataloader)