In [1]:
from glob import glob

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from algo.models.transformer.data import TactileDataset

In [10]:
# Define the CNN model
class MyCNNModel(nn.Module):
    """Three-stage convolutional encoder followed by a linear regression head.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2, 2).
    The convolutions preserve height/width and each pool floors them by 2, so
    after three stages the feature map is ``64 x (H // 8) x (W // 8)``.
    The defaults reproduce the original hard-coded 3x224x224 -> 1 model
    (224 // 8 == 28, i.e. ``Linear(64 * 28 * 28, 1)``).

    Args:
        in_channels: number of channels of the input tensor (default 3).
        input_size: spatial size of the input, either an int for square
            inputs or an ``(height, width)`` pair (default 224). The linear
            head's input width is derived from it, so it must match the
            images passed to ``forward``.
        out_dim: width of the output (default 1).

    Raises:
        ValueError: if ``input_size`` is too small to survive three 2x pools.
    """

    def __init__(self, in_channels=3, input_size=224, out_dim=1):
        super().__init__()

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.flatten = nn.Flatten()

        # Three floor-halving pools compose to floor division by 8.
        feat_h, feat_w = height // 8, width // 8
        if feat_h == 0 or feat_w == 0:
            raise ValueError(
                f"input_size {input_size!r} is too small: need at least 8x8 "
                f"pixels to survive three 2x2 max-pools"
            )

        # Linear head mapping the flattened feature map to `out_dim` outputs.
        self.fc = nn.Linear(64 * feat_h * feat_w, out_dim)

    def forward(self, x):
        """Map a batch ``(N, in_channels, H, W)`` to ``(N, out_dim)``.

        ``H`` and ``W`` must match ``input_size``; otherwise the flattened
        width will not match ``self.fc.in_features``.
        """
        x = self.features(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Instantiate the model. Use the first GPU when one is available and fall back
# to CPU otherwise, so the notebook still runs on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = MyCNNModel().to(device)

# Sanity-check the output shape on a dummy batch. no_grad() avoids building
# an autograd graph for this throwaway 32x3x224x224 forward pass.
with torch.no_grad():
    dummy_out_shape = model(torch.randn(32, 3, 224, 224, device=device)).shape
dummy_out_shape

torch.Size([32, 1])

In [27]:
# Experiment configuration: batch size, frames per sample, data selection, optimiser.
bs = 8
seq_len = 3  # frames per sample; the training loop reshapes them into the CNN's channel axis
DATA_GLOB = '/common/users/dm1487/inhand_manipulation_data_store/datastore_42_11-23-23/07-47-24/*/*.npz'
N_TRAIN_FILES = 1  # how many of the matched .npz files to train on
LR = 1e-3
SEED = 0

# glob() returns matches in arbitrary, filesystem-dependent order; sort so the
# slice selects the same file(s) on every run and machine.
# (`all_paths` is kept as an alias of the selected list, exactly as before.)
training_files = all_paths = sorted(glob(DATA_GLOB))[:N_TRAIN_FILES]
assert training_files, f'no .npz files matched {DATA_GLOB}'

# Pass seq_len instead of repeating the literal 3 so the dataset and the
# training-loop reshape cannot drift apart.
ds = TactileDataset(files=training_files, full_sequence=False, sequence_length=seq_len)
# A seeded generator makes the shuffle order reproducible across runs.
dl = DataLoader(ds, batch_size=bs, shuffle=True, generator=torch.Generator().manual_seed(SEED))
mse = nn.MSELoss()
optim = torch.optim.AdamW(model.parameters(), lr=LR)

In [28]:
# Train the CNN to regress latent[:, -1, :1] (first latent component, last timestep).
N_EPOCHS = 1000
LOG_EVERY = 10  # epochs between loss printouts

# Follow whatever device the model lives on instead of hard-coding 'cuda:0'.
device = next(model.parameters()).device
model.train()

losses = []
for mm in range(N_EPOCHS):
    for cnn_input, lin_input, obs_hist, latent, action, mask in dl:
        # Fold the seq_len frames into the channel axis expected by the CNN.
        # NOTE(review): assumes cnn_input[0] holds B * seq_len * 224 * 224
        # elements per batch (see the earlier debug reshape) — confirm
        # against TactileDataset.
        inp = cnn_input[0].reshape(-1, seq_len, 224, 224).to(device)
        pred = model(inp)
        # Match the prediction's dtype so MSELoss never sees mixed precisions.
        target = latent[:, -1, :1].to(device, dtype=pred.dtype)
        loss = mse(pred, target)
        optim.zero_grad()
        loss.backward()
        optim.step()
        losses.append(loss.item())
    if mm % LOG_EVERY == 0:
        # Mean over all batches since the previous printout; guard against an
        # empty loader, where np.mean([]) would warn and return nan.
        mean_loss = np.mean(losses) if losses else float('nan')
        print('step', mm, ':', mean_loss)
        losses = []

step 0 : 1.4345839023590088
step 10 : 0.3280460398644209
step 20 : 0.20603969715593848
step 30 : 0.18846148140728475
step 40 : 0.14806016029324381
step 50 : 0.16448919028043746
step 60 : 0.2939487662166357
step 70 : 0.14273578071733936
step 80 : 0.15762271652929485
step 90 : 0.23757987096905708
step 100 : 0.1853393204510212
step 110 : 0.15054594688117504
step 120 : 0.1703864661976695
step 130 : 0.23959612026810645
step 140 : 0.1412281469674781
step 150 : 0.163984714448452
step 160 : 0.12123349495232105
step 170 : 0.19771757647395133
step 180 : 0.3563255459070206
step 190 : 0.13081115968525409
step 200 : 0.15955711305141448
step 210 : 0.16909231171011924
step 220 : 0.11591201573610306
step 230 : 0.12099282897543162
step 240 : 0.13511678650975228
step 250 : 0.09150374946184456
step 260 : 0.11378575302660465
step 270 : 0.22658030409365892
step 280 : 0.11702201727312059
step 290 : 0.09904020731337368
step 300 : 0.24800135221375968
step 310 : 0.15056551359593867
step 320 : 0.107787827029824