In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from utils import device


class Policy(nn.Module):
    """Recurrent visuomotor policy: six conv/max-pool stages feeding one LSTM layer.

    Each camera frame is encoded by six ``Conv2d -> ELU -> MaxPool2d`` stages,
    flattened, concatenated with the proprioceptive vector, passed through
    dropout and a single-layer LSTM, and mapped to a ``tanh``-squashed action.

    The flattened visual feature size is hard-coded to 256, which is what the
    six stages produce for a 256x256 input (16 channels x 4 x 4). Other image
    sizes will generally not match the LSTM input width.

    Config keys read here: ``proprio_dim``, ``action_dim``, ``learning_rate``,
    ``weight_decay``.

    The module owns its Adam optimizer (``self.optimizer``) and a fixed action
    standard deviation of 0.1 per dimension (``self.std``).
    """

    def __init__(self, config):
        super(Policy, self).__init__()
        # 256 = 16 channels * 4 * 4 after the sixth pooling stage on 256x256 input.
        lstm_dim = 256 + config["proprio_dim"]
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=3, padding=2, stride=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=4, out_channels=4, kernel_size=3, padding=2, stride=1
        )
        self.conv3 = nn.Conv2d(
            in_channels=4, out_channels=8, kernel_size=3, padding=2, stride=1
        )
        self.conv4 = nn.Conv2d(
            in_channels=8, out_channels=8, kernel_size=3, padding=2, stride=1
        )
        self.conv5 = nn.Conv2d(
            in_channels=8, out_channels=16, kernel_size=3, padding=2, stride=1
        )
        self.conv6 = nn.Conv2d(
            in_channels=16, out_channels=16, kernel_size=3, padding=2, stride=1
        )
        # Sequence-first LSTM (batch_first is left at its default of False).
        self.lstm = nn.LSTM(lstm_dim, lstm_dim)
        self.linear_out = nn.Linear(lstm_dim, config["action_dim"])
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=config["learning_rate"],
            weight_decay=config["weight_decay"],
        )
        # Registered as a buffer so it follows .to()/.cuda()/.cpu() together with
        # the parameters. Non-persistent keeps the state_dict keys unchanged, so
        # existing checkpoints still load.
        self.register_buffer(
            "std",
            0.1 * torch.ones(config["action_dim"], dtype=torch.float32),
            persistent=False,
        )
        self.dropout = nn.Dropout(p=0.4)

    def forward_step(self, camera_obs, proprio_obs, lstm_state):
        """Run one time step of the policy.

        Args:
            camera_obs: image batch fed to ``conv1`` (3 input channels); the
                first dimension is treated as the batch dimension.
            proprio_obs: proprioceptive batch concatenated to the flattened
                visual features along the last dimension.
            lstm_state: ``(h, c)`` tuple from the previous step, or ``None``
                to start from the LSTM's default zero state.

        Returns:
            ``(out, lstm_state)`` where ``out`` is the ``tanh``-squashed action
            with a leading sequence dimension of length 1, and ``lstm_state``
            is the updated ``(h, c)`` tuple.
        """
        vis_encoding = camera_obs
        conv_stages = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.conv4,
            self.conv5,
            self.conv6,
        )
        for conv in conv_stages:
            vis_encoding = self.maxpool(F.elu(conv(vis_encoding)))
        vis_encoding = torch.flatten(vis_encoding, start_dim=1)
        # unsqueeze(0) adds the length-1 sequence dimension the LSTM expects.
        low_dim_input = torch.cat((vis_encoding, proprio_obs), dim=-1).unsqueeze(0)
        low_dim_input = self.dropout(low_dim_input)
        lstm_out, lstm_state = self.lstm(low_dim_input, lstm_state)
        out = torch.tanh(self.linear_out(lstm_out))
        return out, lstm_state

    def forward(self, camera_obs_traj, proprio_obs_traj, action_traj, feedback_traj):
        """Compute the feedback-weighted negative log-likelihood of a trajectory.

        All four arguments are indexed along their first dimension, one entry
        per time step; the number of steps is ``len(proprio_obs_traj)`` and must
        be at least one. At each step the recorded action is scored under
        ``Normal(mu, self.std)`` and the negative log-probability is multiplied
        by the corresponding feedback entry.

        Returns:
            Scalar tensor: the mean of the weighted losses over all steps and
            elements.
        """
        losses = []
        lstm_state = None
        for idx in range(len(proprio_obs_traj)):
            mu, lstm_state = self.forward_step(
                camera_obs_traj[idx], proprio_obs_traj[idx], lstm_state
            )
            distribution = Normal(mu, self.std)
            log_prob = distribution.log_prob(action_traj[idx])
            loss = -log_prob * feedback_traj[idx]
            losses.append(loss)
        total_loss = torch.cat(losses).mean()
        return total_loss

    def update_params(
        self, camera_obs_traj, proprio_obs_traj, action_traj, feedback_traj
    ):
        """Run one optimizer step on a batch of trajectories.

        The inputs are moved to the device the module currently lives on, the
        loss from ``forward`` is back-propagated, and ``self.optimizer`` is
        stepped once.

        Returns:
            ``{"loss": tensor}`` where the tensor is detached from the autograd
            graph, so holding on to the metrics does not keep the graph alive.
        """
        # The buffer moves with the module, so its device is the module's device.
        target_device = self.std.device
        camera_obs = camera_obs_traj.to(target_device)
        proprio_obs = proprio_obs_traj.to(target_device)
        action = action_traj.to(target_device)
        feedback = feedback_traj.to(target_device)
        self.optimizer.zero_grad()
        loss = self.forward(camera_obs, proprio_obs, action, feedback)
        loss.backward()
        self.optimizer.step()
        training_metrics = {"loss": loss.detach()}
        return training_metrics

    def predict(self, camera_obs, proprio_obs, lstm_state):
        """Predict one action for a single, unbatched observation.

        ``camera_obs`` and ``proprio_obs`` are converted to float32 tensors and
        given a leading batch dimension of 1. The last component of the
        returned action is passed through ``binary_gripper``.

        This method does not switch the module to evaluation mode: unless the
        caller has invoked ``eval()``, the dropout layer stays active.

        Returns:
            ``(action, lstm_state)``: the action as a NumPy array with the
            sequence and batch dimensions squeezed away, and the updated LSTM
            state to pass into the next call.
        """
        target_device = self.std.device
        camera_obs_th = torch.tensor(camera_obs, dtype=torch.float32).unsqueeze(0)
        proprio_obs_th = torch.tensor(proprio_obs, dtype=torch.float32).unsqueeze(0)
        camera_obs_th = camera_obs_th.to(target_device)
        proprio_obs_th = proprio_obs_th.to(target_device)
        with torch.no_grad():
            action_th, lstm_state = self.forward_step(
                camera_obs_th, proprio_obs_th, lstm_state
            )
            action = action_th.detach().cpu().squeeze(0).squeeze(0).numpy()
            action[-1] = binary_gripper(action[-1])
        return action, lstm_state


def binary_gripper(gripper_action):
    """Snap a scalar gripper command to one of two fixed levels.

    Returns 0.9 for any value >= 0.0 and -0.9 for any value < 0.0. A value
    that satisfies neither comparison (e.g. NaN) is returned unchanged.
    """
    if gripper_action >= 0.0:
        return 0.9
    return -0.9 if gripper_action < 0.0 else gripper_action


In [None]:
# Experiment configuration shared by the cells below.
config = {
    # Run labels; not read by any code visible in this notebook.
    "feedback_type": "cloning_100",
    "task": "stator_100",
    # Sizes read by Policy.__init__.
    "proprio_dim": 8,
    "action_dim": 7,
    # NOTE(review): Policy hard-codes 256 for the visual feature size and does
    # not read this key; keep the two in sync if either changes.
    "visual_embedding_dim": 256,
    # Adam hyper-parameters read by Policy.__init__.
    "learning_rate": 3e-4,
    "weight_decay": 3e-6,
    # Number of items requested from the replay memory per sample (16 was an
    # earlier value).
    "batch_size": 8,
}

In [None]:
# Build the policy from the shared config and move it to the project-wide device.
policy = Policy(config).to(device)

# Count only the parameters that receive gradients.
pytorch_total_params = sum(
    param.numel() for param in policy.parameters() if param.requires_grad
)

print(pytorch_total_params)

566163


In [67]:
# NOTE(review): absolute, machine-specific path -- this cell only runs on the
# original workstation. torch.load unpickles the file, so only load data from a
# trusted source.
replay_memory = torch.load(
    "/home/faps/CEILing/CEILing256_v2/data/stators_100/demos_100.dat"
)
# Draw one batch and split it into its four components.
batch = replay_memory.sample(config["batch_size"])
(camera_batch, proprio_batch, action_batch, feedback_batch) = batch

In [9]:
# Random stand-in for a batch of eight 3-channel 256x256 camera frames, used
# only to trace tensor shapes through the stand-alone conv/pool layers below.
input_tensor = torch.randn(8, 3, 256, 256)

# The batch is deliberately left on the CPU: the stand-alone layers used for
# the shape check are created with PyTorch's default (CPU) placement, so moving
# only the input to the GPU would cause a device mismatch.
input_batch = input_tensor
    

In [56]:
# Stand-alone copies of the six convolution stages and the shared pooling layer,
# with the same hyper-parameters as Policy, used to inspect intermediate shapes.
# Every convolution is 3x3, stride 1, padding 2; only the channel counts change.
conv1 = nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=2)
conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=2)
conv3 = nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=2)
conv4 = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=2)
conv5 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=2)
conv6 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=2)

# 3x3 max-pooling with stride 2, applied after every convolution.
maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

In [57]:
# Stage 1: 256x256 input -> conv (k=3, p=2) gives 258x258 -> max-pool (k=3, s=2)
# gives 128x128. The ELU used inside Policy is skipped here; it does not change
# shapes.
vis_encoding = maxpool(conv1(input_batch))
print(vis_encoding.shape)

torch.Size([8, 4, 128, 128])


In [58]:
# Stage 2: halves the spatial size again (128 -> 64). This cell overwrites
# vis_encoding with its own output, so it must run exactly once after stage 1.
vis_encoding = maxpool(conv2(vis_encoding))
print(vis_encoding.shape)

torch.Size([8, 4, 64, 64])


In [59]:
# Stage 3: channels 4 -> 8, spatial size 64 -> 32. Run once, after stage 2.
vis_encoding = maxpool(conv3(vis_encoding))
print(vis_encoding.shape)

torch.Size([8, 8, 32, 32])


In [60]:
# Stage 4: channels stay at 8, spatial size 32 -> 16. Run once, after stage 3.
vis_encoding = maxpool(conv4(vis_encoding))
print(vis_encoding.shape)

torch.Size([8, 8, 16, 16])


In [61]:
# Stage 5: channels 8 -> 16, spatial size 16 -> 8. Run once, after stage 4.
vis_encoding = maxpool(conv5(vis_encoding))
print(vis_encoding.shape)

torch.Size([8, 16, 8, 8])


In [62]:
# Stage 6: channels stay at 16, spatial size 8 -> 4. Flattened per sample this
# is 16 * 4 * 4 = 256 features, the constant Policy adds to proprio_dim.
vis_encoding = maxpool(conv6(vis_encoding))
print(vis_encoding.shape)

torch.Size([8, 16, 4, 4])


In [68]:
# One optimisation step on the sampled batch; the returned dict holds the loss.
training_metrics = policy.update_params(
    camera_batch, proprio_batch, action_batch, feedback_batch
)

training_metrics

{'loss': tensor(13.8156, device='cuda', grad_fn=<MeanBackward0>)}

In [69]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return the total.

    Only parameters with ``requires_grad`` set are listed and counted. The
    table is printed first, followed by a one-line total.
    """
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [70]:
# Per-tensor breakdown of the policy's trainable parameters; the returned total
# is also shown as the cell's value.
count_parameters(policy)

+-------------------+------------+
|      Modules      | Parameters |
+-------------------+------------+
|    conv1.weight   |    108     |
|     conv1.bias    |     4      |
|    conv2.weight   |    144     |
|     conv2.bias    |     4      |
|    conv3.weight   |    288     |
|     conv3.bias    |     8      |
|    conv4.weight   |    576     |
|     conv4.bias    |     8      |
|    conv5.weight   |    1152    |
|     conv5.bias    |     16     |
|    conv6.weight   |    2304    |
|     conv6.bias    |     16     |
| lstm.weight_ih_l0 |   278784   |
| lstm.weight_hh_l0 |   278784   |
|  lstm.bias_ih_l0  |    1056    |
|  lstm.bias_hh_l0  |    1056    |
| linear_out.weight |    1848    |
|  linear_out.bias  |     7      |
+-------------------+------------+
Total Trainable Params: 566163


566163