<a href="https://colab.research.google.com/github/dylanstephens1997/L4DC_Term_Project/blob/main/pytorch_LDT_CVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
input = torch.load('/content/drive/MyDrive/Colab Notebooks/input_demos_gcg.pt')
actions = torch.load('/content/drive/MyDrive/Colab Notebooks/action_demos_gcg.pt')
if (torch.is_tensor(input)):
    input_t = input
    action_t = actions
else:
    input_t  = torch.from_numpy(input)
    action_t = torch.from_numpy(actions)

In [None]:
# prerequisites
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import numpy as np

bs = 100

In [None]:
from torch.utils.data import Dataset, DataLoader, TensorDataset
class TeleopData(Dataset):
    def __init__(self, actions, states):
        self.action = actions
        self.state = states

    def __len__(self):
        return len(self.state)

    def __getitem__(self, idx):
        actions = self.action[idx]
        states  = self.state[idx]
        sample = (actions,states)
        return sample

In [None]:
import pandas as pd
#state_action_df = pd.DataFrame({'State': input, 'Actions': actions})
TD = TeleopData(action_t, input_t)
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')
train_dataset = TensorDataset(action_t, input_t)
test_dataset = TensorDataset(action_t[0:200], input_t[0:200])



First iteration of data set:  (tensor([ 0.,  0.,  0.,  0.,  0., -0., -1.], dtype=torch.float64), tensor([295.8856, 415.1305,   0.9998,   0.9774,   0.9996,  -0.8776,   0.9999,
         -0.9755,   0.7004,   1.0000,   0.0000], dtype=torch.float64)) 



In [None]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=30, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=10, shuffle=False)

In [None]:
print(action_t[0:200])

tensor([[ 0.,  0.,  0.,  ...,  0., -0., -1.],
        [ 0.,  0.,  0.,  ...,  0., -0., -1.],
        [ 0.,  0.,  0.,  ...,  0., -0., -1.],
        ...,
        [ 0.,  0.,  0.,  ...,  0., -0., -1.],
        [ 0.,  0.,  0.,  ...,  0., -0., -1.],
        [ 0.,  0.,  0.,  ...,  0., -0., -1.]], dtype=torch.float64)


In [None]:
for batch_idx, (data, cond) in enumerate(train_loader):
        print("data: ", data.shape )
        print("cond: ", cond.shape )
        data, cond = data.cuda(), cond.cuda()
        break


data:  torch.Size([50, 7])
cond:  torch.Size([50, 11])


In [None]:
print(TD)

<__main__.TeleopData object at 0x7f1ca400bb10>


In [None]:
class CVAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim, c_dim):
        super(CVAE, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim + c_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim + c_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
    
    def encoder(self, x, c):
        concat_input = torch.cat([x, c], 1)
        h = F.relu(self.fc1(concat_input))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu) # return z sample
    
    def decoder(self, z, c):
        concat_input = torch.cat([z, c], 1)
        h = F.relu(self.fc4(concat_input))
        h = F.relu(self.fc5(h))
        return F.sigmoid(self.fc6(h))
    
    def forward(self, x, c):
        mu, log_var = self.encoder(x.view(-1, 7), c)
        z = self.sampling(mu, log_var)
        return self.decoder(z, c), mu, log_var

# build model
#cond_dim = train_loader.dataset.train_labels.unique().size(0)
cvae = CVAE(x_dim=len(action_t[0]), h_dim1=30, h_dim2=10, z_dim=1, c_dim=len(input_t[0]))
if torch.cuda.is_available():
    cvae.cuda()

In [None]:
cvae
cvae = cvae.double()

In [None]:
optimizer = optim.Adam(cvae.parameters())
# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 7), reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

# one-hot encoding (not necessary for us commenting out)
'''
def one_hot(labels, class_size): 
    targets = torch.zeros(labels.size(0), class_size)
    for i, label in enumerate(labels):
        targets[i, label] = 1
    return Variable(targets)'''

'\ndef one_hot(labels, class_size): \n    targets = torch.zeros(labels.size(0), class_size)\n    for i, label in enumerate(labels):\n        targets[i, label] = 1\n    return Variable(targets)'

In [None]:
def train(epoch):
    cvae.train()
    train_loss = 0
    for batch_idx, (data, cond) in enumerate(train_loader):
        data, cond = data.cuda(), cond.cuda()
        optimizer.zero_grad()
        
        recon_batch, mu, log_var = cvae(data.double(), cond.double())
        loss = loss_function(recon_batch, data, mu, log_var)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))                

In [None]:
def test():
    cvae.eval()
    test_loss= 0
    with torch.no_grad():
        for data, cond in test_loader:
            data, cond = data.cuda(), cond.cuda()
            recon, mu, log_var = cvae(data, cond)
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()
        
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
# train
for epoch in range(1, 10):
    train(epoch)
    test()



====> Epoch: 1 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 2 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 3 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 4 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 5 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 6 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 7 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 8 Average loss: 9.5727
====> Test set loss: -1.1304
====> Epoch: 9 Average loss: 9.5727
====> Test set loss: -1.1304


In [None]:
torch.save(cvae.fc1.weight,'/content/drive/MyDrive/Colab Notebooks/fc1_weights.pt')
torch.save(cvae.fc2.weight,'/content/drive/MyDrive/Colab Notebooks/fc2_weights.pt')
torch.save(cvae.fc31.weight,'/content/drive/MyDrive/Colab Notebooks/fc31_weights.pt')
torch.save(cvae.fc32.weight,'/content/drive/MyDrive/Colab Notebooks/fc32_weights.pt')
torch.save(cvae.fc4.weight,'/content/drive/MyDrive/Colab Notebooks/fc4_weights.pt')
torch.save(cvae.fc5.weight,'/content/drive/MyDrive/Colab Notebooks/fc5_weights.pt')
torch.save(cvae.fc6.weight,'/content/drive/MyDrive/Colab Notebooks/fc6_weights.pt')



In [None]:
test()

====> Test set loss: -1.1304




In [None]:
with torch.no_grad():
    z = torch.ones(2100, 1).cuda()
    c = input_t[0:2100].cuda()

    sample = cvae.decoder(z, c)
print(torch.sum(sample,0))
print(sample)

tensor([5.2961e+01, 1.9656e+03, 3.6072e-02, 4.1915e+02, 2.0947e+03, 4.3015e-02,
        2.4235e-07], device='cuda:0', dtype=torch.float64)
tensor([[2.9167e-02, 9.3462e-01, 1.5220e-05,  ..., 9.9831e-01, 1.5984e-05,
         6.4407e-11],
        [2.9165e-02, 9.3462e-01, 1.5221e-05,  ..., 9.9831e-01, 1.5984e-05,
         6.4412e-11],
        [2.9163e-02, 9.3461e-01, 1.5223e-05,  ..., 9.9831e-01, 1.5985e-05,
         6.4424e-11],
        ...,
        [1.4546e-02, 9.4350e-01, 1.9242e-05,  ..., 9.9439e-01, 3.1133e-05,
         2.1579e-10],
        [1.4546e-02, 9.4351e-01, 1.9240e-05,  ..., 9.9439e-01, 3.1132e-05,
         2.1576e-10],
        [1.4547e-02, 9.4351e-01, 1.9238e-05,  ..., 9.9440e-01, 3.1131e-05,
         2.1574e-10]], device='cuda:0', dtype=torch.float64)


