In [11]:
import os
import shutil
from random import randint

import numpy as np
import pandas as pd
import torch
import wandb
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader

import common_constants as cc
import common_functions as cf

In [12]:
# Use the first CUDA GPU when one is visible to torch; otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Weights & Biases project passed to wandb.init() for this notebook's runs.
WANDB_PROJECT_NAME = 'AutoEncoder'

In [14]:
# Root folder for every saved model. `exist_ok=True` makes re-running this cell a no-op.
os.makedirs(cc.MODELS_PATH, exist_ok=True)

In [15]:
class LinearAutoEncoder(torch.nn.Module):
    """Fully connected autoencoder that reconstructs a flattened sample.

    Each sample is flattened from dim 1 on, encoded to a ``latent_dim`` vector
    and decoded back to ``n_points * n_channels`` values. The output stays
    flat, shape (batch, n_points * n_channels), so it is compared against the
    flattened input.

    Args:
        n_points: size of the first per-sample dimension (default 677).
        n_channels: size of the second per-sample dimension (default 3).
        latent_dim: width of the bottleneck layer (default 32).
    """

    def __init__(self, n_points=677, n_channels=3, latent_dim=32):
        super().__init__()
        n_features = n_points * n_channels

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(n_features, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, latent_dim),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, n_features),
        )

    def forward(self, x):
        """Return the flat reconstruction of ``x``, shape (batch, n_points * n_channels)."""
        x = torch.flatten(x, start_dim=1)
        return self.decoder(self.encoder(x))

In [16]:
def _make_loader(txt_path, batch_size):
    """Build an AEAcceleratorDataset from ``txt_path`` and wrap it in an unshuffled DataLoader."""
    dataset = cf.AEAcceleratorDataset(txt_path)
    # NOTE(review): the training loader is not shuffled either; consider shuffle=True for training.
    return dataset, DataLoader(dataset, batch_size=batch_size, shuffle=False)


train_dataset, train_dataloader = _make_loader(cc.TRAIN_TXT_PATH, cc.TRAIN_BATCH_SIZE)
test_dataset, test_dataloader = _make_loader(cc.TEST_TXT_PATH, cc.TEST_BATCH_SIZE)

Linear Model Training

In [None]:
# Open a Weights & Biases run under this notebook's project; the run object is the cell's output.
wandb_run = wandb.init(project=WANDB_PROJECT_NAME)
wandb_run

In [17]:
# Output folders for the linear model: <MODELS_PATH>/LinearModel/models/.
sub_project_path = os.path.join(cc.MODELS_PATH, 'LinearModel')
models_path = os.path.join(sub_project_path, 'models')
# makedirs also creates sub_project_path as an intermediate directory; exist_ok keeps re-runs idempotent.
os.makedirs(models_path, exist_ok=True)

# The training loop moves every batch to DEVICE, so the parameters must live there too.
model = LinearAutoEncoder().to(DEVICE)
loss_func = torch.nn.MSELoss()
optimizer = Adam(model.parameters())  # alternative tried: SGD(model.parameters())

In [None]:
def _reconstruction_loss(batch):
    """Return the loss between the model's output for ``batch`` and the flattened batch.

    The batch is cast to float32 and moved to DEVICE first. The target is the
    batch flattened from dim 1 on, matching the model's flat output.
    """
    batch = batch.to(device=DEVICE, dtype=torch.float32)
    target = torch.flatten(batch, start_dim=1)
    return loss_func(model(batch), target)


# Batches go to DEVICE, so the parameters must be there as well (no-op if already moved).
model.to(DEVICE)

best_val = float('inf')
best_val_idx = -1
n_batches = len(train_dataloader)
for epoch in range(cc.EPOCHS):
    # --- train ---
    model.train()
    train_losses = []
    for idx, data in enumerate(train_dataloader):
        optimizer.zero_grad()
        loss = _reconstruction_loss(data)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        print("", end='\rEpoch: {}/{} | Batch: {}/{} | loss: {}'.format(epoch, cc.EPOCHS, idx, n_batches, np.mean(train_losses)))
    epoch_final_loss = np.mean(train_losses)

    # --- evaluate: no_grad avoids building an autograd graph for validation batches ---
    model.eval()
    validation_losses = []
    with torch.no_grad():
        # Explicit loop (not a comprehension) so `data` is left holding the last test batch,
        # which a later scratch cell reads.
        for data in test_dataloader:
            validation_losses.append(_reconstruction_loss(data).item())
    validation_final_loss = np.mean(validation_losses)
    print('\nEpoch: {}/{} | train_loss: {} | val_loss: {}\n'.format(epoch, cc.EPOCHS, epoch_final_loss, validation_final_loss))

    # Whole-module pickle (not a state_dict): the visualisation cell reloads it with torch.load.
    torch.save(model, os.path.join(models_path, 'epoch_{}.pt'.format(epoch)))

    # Log only when a W&B run is active, so this cell also works without wandb.init().
    if wandb.run is not None:
        wandb.log({
            "epoch_loss": epoch_final_loss,
            "epoch_validation_loss": validation_final_loss,
        })

    if best_val_idx == -1 or validation_final_loss < best_val:
        best_val = validation_final_loss
        best_val_idx = epoch

# Copy (rather than rename) so epoch_<best>.pt stays available; copyfile overwrites an old best.pt.
# Guarded so that zero epochs does not try to copy a non-existent epoch_-1.pt.
if best_val_idx != -1:
    best_path = os.path.join(models_path, 'epoch_{}.pt'.format(best_val_idx))
    new_best_path = os.path.join(models_path, 'best.pt')
    shutil.copyfile(best_path, new_best_path)

In [None]:
# Reload the best checkpoint written by the training cell (best.pt always exists after training,
# whereas an epoch_<n>.pt may have been renamed to best.pt). The file is a whole pickled module
# produced by this notebook, i.e. trusted, hence weights_only=False (newer torch defaults to True
# and refuses to unpickle modules). map_location lets a GPU-trained checkpoint load on CPU.
model = torch.load(os.path.join(models_path, 'best.pt'), map_location=DEVICE, weights_only=False)
model.eval()
n_test_sample = len(test_dataset)

last_pred = None
for show_i in range(cc.SHOW_N_TESTS):
    idx = randint(0, n_test_sample - 1)

    # Project API: the second argument makes the dataset return the sample as a DataFrame.
    df = test_dataset.__getitem__(idx, True)

    # Every column but the last is model input. Slice before to_numpy so a non-numeric
    # last column cannot turn the array into dtype=object.
    features = df.iloc[:, :-1].to_numpy(dtype=np.float32)
    model_input = torch.flatten(torch.tensor(features).unsqueeze(0), start_dim=1).to(DEVICE)

    with torch.no_grad():
        prediction = model(model_input)
    # Back to the per-sample (rows, feature columns) layout; .cpu() is required before .numpy()
    # when the model runs on a CUDA device.
    prediction = prediction.reshape(features.shape).cpu().numpy()

    # Debug aid: report whether consecutive samples get identical reconstructions.
    if last_pred is not None:
        print((prediction == last_pred).all())
    last_pred = prediction

    # Predicted feature columns, plus the last column copied unchanged from the ground truth.
    new_df = pd.DataFrame(prediction, columns=df.columns[:-1], index=df.index)
    new_df[df.columns[-1]] = df[df.columns[-1]]

    gt_path = os.path.join(sub_project_path, "GT_{}.jpg".format(show_i))
    cf.PlotRecordData(df, False, False, False, gt_path, False)
    pred_path = os.path.join(sub_project_path, "Prediction_{}.jpg".format(show_i))
    cf.PlotRecordData(new_df, False, False, False, pred_path, False)

Convolutional Autoencoder

In [None]:
class ConvAutoEncoder(torch.nn.Module):
    """1-D convolutional autoencoder over 3-channel sequences.

    Expects input shaped (batch, 3, length). Working through the layer arithmetic:
    - A length-677 input is encoded to (batch, 64, 1) and decoded back to (batch, 3, 677).
    - The final ConvTranspose1d's padding/dilation bring the decoded length back to exactly 677.
    """

    def __init__(self):
        super().__init__()

        # Encoder: three Conv1d(kernel 10, stride 5) + ReLU stages, then a kernel-4 Conv1d.
        encoder_layers = []
        for in_ch, out_ch in ((3, 8), (8, 16), (16, 32)):
            encoder_layers.append(torch.nn.Conv1d(in_ch, out_ch, kernel_size=10, stride=5))
            encoder_layers.append(torch.nn.ReLU())
        encoder_layers.append(torch.nn.Conv1d(32, 64, kernel_size=4, stride=1))
        self.encoder = torch.nn.Sequential(*encoder_layers)

        # Decoder: roughly mirrors the encoder with ConvTranspose1d layers.
        decoder_layers = [torch.nn.ConvTranspose1d(64, 32, kernel_size=4, stride=1), torch.nn.ReLU()]
        for in_ch, out_ch in ((32, 16), (16, 8)):
            decoder_layers.append(torch.nn.ConvTranspose1d(in_ch, out_ch, kernel_size=10, stride=5))
            decoder_layers.append(torch.nn.ReLU())
        decoder_layers.append(
            torch.nn.ConvTranspose1d(8, 3, kernel_size=6, stride=5, padding=2, dilation=7)
        )
        self.decoder = torch.nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode then decode ``x``; the output has 3 channels, like the input."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
 # Swap dims 1 and 2 (presumably (batch, 677, 3) -> (batch, 3, 677), the channels-first layout Conv1d expects) and cast to float32 on DEVICE.
 data = data.permute(0, 2, 1).to(device=DEVICE, dtype=torch.float32)