In [None]:
!sudo apt-get update
!sudo apt-get install gcc

!sudo apt-get build-dep mesa
!sudo apt-get install llvm-dev
!sudo apt-get install freeglut3 freeglut3-dev

!sudo apt-get install python3-dev

!sudo apt-get install build-essential

!sudo apt install curl git libgl1-mesa-dev libgl1-mesa-glx libglew-dev \
        libosmesa6-dev software-properties-common net-tools unzip vim \
        virtualenv wget xpra xserver-xorg-dev libglfw3-dev patchelf

#!sudo apt-get install -y libglew-dev

!pip install pytorch_lightning
!pip install einops

!pip install networkx
!pip install pybind11
!pip install graph-walker
!pip install glfw


Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:5 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:7 https://ppa.launchpadcontent.net/c2d4u.team/c2d4u4.0+/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists... Done
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
gcc is already the newest version (4:11.2.0-1ubuntu1).
gcc set to manually installed.
0 upgraded, 0 newly installed, 0 to remove and 50 no

In [None]:
!wget https://roboti.us/download/mujoco200_linux.zip

!wget https://roboti.us/file/mjkey.txt
!mkdir /root/.mujoco

### mujoco 210
#!tar -xf mujoco210-linux-x86_64.tar.gz -C /.mujoco/
#!ls -alh /.mujoco/mujoco210

### mujoco 200
!unzip mujoco200_linux.zip -d /root/.mujoco/
!cp -r /root/.mujoco/mujoco200_linux /root/.mujoco/mujoco200

!mv mjkey.txt /root/.mujoco/

!cp -r /root/.mujoco/mujoco200/bin/* /usr/lib/

!ls -alh /root/.mujoco/
%env LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco200/bin
!pip install mujoco_py==2.0.2.8




In [None]:
# Terminates the current kernel process. NOTE(review): presumably used to force a runtime
# restart so the libraries/packages installed above are picked up by a fresh process — confirm.
# Because the kernel dies here, "Restart & Run All" stops at this cell; later cells must be run manually.
exit()


In [None]:
# Pin a specific Cython pre-release. NOTE(review): presumably required because the
# mujoco_py==2.0.2.8 extension does not compile with other Cython versions — confirm.
!pip install Cython==3.0.0a10

In [None]:
# Imports for the training cells. The %env line re-exports the MuJoCo library directory
# (also set in an earlier cell) so it is present even after the kernel was restarted.
%env LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco200/bin

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

# NOTE(review): several imports in this cell (torchvision, argparse, defaultdict, typing,
# logging, partial, math, scipy.special, F, rank_zero_only, einops) are not referenced by
# the cells in this notebook; drop them unless the S4 code on Drive needs them here.
import torchvision
import torchvision.transforms as transforms

import os
import argparse

from collections import defaultdict
from typing import Optional, Mapping, Tuple, Union
import logging
from functools import partial
import math
import numpy as np
from scipy import special as ss

import torch.nn.functional as F


from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange, repeat

# Short alias for torch.einsum (not referenced elsewhere in this notebook).
contract = torch.einsum

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import sys

# Project checkout on Google Drive that provides the `models.s4` package.
S4_REPO_DIR = '/content/drive/MyDrive/6_8200_project/S4'
# Guard so re-running this cell does not keep appending duplicate sys.path entries.
if S4_REPO_DIR not in sys.path:
    sys.path.append(S4_REPO_DIR)

from models.s4.s4 import S4Block as S4  # full S4 block; S4Model below uses S4D instead
from models.s4.s4d import S4D
from tqdm.auto import tqdm


In [None]:
%env LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco200/bin

import gym

env = gym.make('Walker2d-v3')
env.reset()
env.step(env.action_space.sample())
env.close()
print("mujoco-py check passed")

In [None]:
import gym
# import walker
from gym.envs.registration import registry, register


In [None]:
def split_train_val(train, val_split):
    """Randomly split ``train`` into (train, val) subsets.

    ``val_split`` is the fraction held out for validation: the training subset gets
    ``int(len(train) * (1.0 - val_split))`` items and validation gets the rest.
    The permutation is seeded with 42, so the split is the same on every call.
    """
    total = len(train)
    n_train = int(total * (1.0-val_split))
    n_val = total - n_train
    seeded = torch.Generator().manual_seed(42)
    subsets = torch.utils.data.random_split(train, (n_train, n_val), generator=seeded)
    return subsets[0], subsets[1]


class S4Model(nn.Module):
    """Residual stack of S4D layers between a linear encoder and a linear decoder.

    Input ``x`` has shape (B, L, d_input); the output has shape (B, d_output).
    Each of the ``n_layers`` blocks applies: LayerNorm first when ``prenorm=True``,
    then S4D, dropout and a residual add, then LayerNorm when ``prenorm=False``.
    The sequence axis is mean-pooled before decoding, so one output vector is
    produced per input sequence.

    NOTE(review): construction reads two notebook globals — ``lr`` (each S4D layer
    gets ``lr=min(0.001, lr)``) and ``dropout_fn`` (the dropout class) — which must
    be defined before the model is instantiated.
    """

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
    ):
        super().__init__()

        self.prenorm = prenorm

        # Per-timestep projection from the input feature size to the model width.
        self.encoder = nn.Linear(d_input, d_model)

        # Block i of the stack is (s4_layers[i], norms[i], dropouts[i]); modules are
        # created in that order, one block at a time.
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            # transposed=True: the S4D layer works on (B, d_model, L) tensors.
            s4_block = S4D(d_model, dropout=dropout, transposed=True, lr=min(0.001, lr))
            self.s4_layers.append(s4_block)
            self.norms.append(nn.LayerNorm(d_model))
            self.dropouts.append(dropout_fn(dropout))

        # Projection from the pooled features to the output size.
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """Map ``x`` of shape (B, L, d_input) to an output of shape (B, d_output)."""

        def channel_norm(norm, t):
            # LayerNorm normalises the last axis, so move d_model there and back.
            return norm(t.transpose(-1, -2)).transpose(-1, -2)

        # (B, L, d_input) -> (B, L, d_model) -> (B, d_model, L)
        h = self.encoder(x).transpose(-1, -2)

        for s4_block, norm, drop in zip(self.s4_layers, self.norms, self.dropouts):
            block_in = channel_norm(norm, h) if self.prenorm else h
            # The S4D layer returns (output, state); the state is discarded.
            block_out, _ = s4_block(block_in)
            h = drop(block_out) + h  # residual connection
            if not self.prenorm:
                h = channel_norm(norm, h)

        # (B, d_model, L) -> (B, L, d_model), mean over L, then -> (B, d_output)
        return self.decoder(h.transpose(-1, -2).mean(dim=1))



def setup_optimizer(model, lr, weight_decay, epochs):
    """Build the AdamW optimizer and cosine LR schedule used for S4 training.

    Parameters carrying an ``_optim`` dict (the S4 kernel parameters, which usually
    want a smaller learning rate and no weight decay) are put into extra parameter
    groups using those hyperparameters, one group per distinct dict. Every other
    parameter goes into the default group with ``lr`` and ``weight_decay``.
    The scheduler is ``CosineAnnealingLR`` with ``T_max=epochs``. One summary line
    per parameter group is printed.

    Returns:
        (optimizer, scheduler)
    """
    every_param = list(model.parameters())
    tagged = [p for p in every_param if hasattr(p, "_optim")]

    # Default group: every parameter without special hyperparameters.
    optimizer = optim.AdamW(
        [p for p in every_param if not hasattr(p, "_optim")],
        lr=lr,
        weight_decay=weight_decay,
    )

    # Distinct special-hyperparameter dicts: de-duplicated in first-seen order, then sorted.
    distinct = list(dict.fromkeys(frozenset(p._optim.items()) for p in tagged))
    distinct.sort()
    hps = [dict(fs) for fs in distinct]

    for hp in hps:
        group_params = [p for p in every_param if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": group_params, **hp})

    # Cosine decay over `epochs` scheduler steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    hp_names = sorted({name for hp in hps for name in hp})
    for idx, group in enumerate(optimizer.param_groups):
        parts = [f"Optimizer group {idx}", f"{len(group['params'])} tensors"]
        parts.extend(f"{name} {group.get(name, None)}" for name in hp_names)
        print(' | '.join(parts))

    return optimizer, scheduler






In [None]:
###############################################################################
# Everything after this point is standard PyTorch training!
###############################################################################

# Training
def train():
    """Run one training epoch of the global ``model`` over the global ``trainloader``.

    Reads the notebook globals ``model``, ``trainloader``, ``device``, ``optimizer``
    and ``criterion``. The progress bar always shows the running mean batch loss.
    It also shows running accuracy, but only when the targets are integer class
    indices of shape (B,). For continuous targets such as (B, act_dim) action vectors,
    accuracy is undefined: comparing ``outputs.max(1)`` against them would not even
    broadcast. In that case only the loss is reported.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(enumerate(trainloader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        mean_loss = train_loss / (batch_idx + 1)

        if targets.dim() == 1 and not torch.is_floating_point(targets):
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(trainloader), mean_loss, 100.*correct/total, correct, total)
            )
        else:
            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f' % (batch_idx, len(trainloader), mean_loss)
            )


def eval(epoch, dataloader, checkpoint=False):
    """Evaluate the global ``model`` on ``dataloader``; optionally checkpoint the best model.

    Reads the notebook globals ``model``, ``device``, ``criterion`` and ``csv_writer``,
    and updates the global ``best_acc`` when checkpointing. One CSV row is written per
    batch: ``[batch_idx, n_batches, mean_loss, acc%, correct, total]`` for class-index
    targets, ``[batch_idx, n_batches, mean_loss]`` otherwise.

    Args:
        epoch: epoch number stored in the checkpoint.
        dataloader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        checkpoint: if True, save the model to the Drive checkpoint directory whenever
            the score beats ``best_acc``.

    Returns:
        A higher-is-better score, returned whether or not ``checkpoint`` is set:
        accuracy in percent when targets are integer class indices of shape (B,),
        otherwise the negated mean batch loss (accuracy is undefined for continuous
        targets such as action vectors). ``None`` if ``dataloader`` yields no batches.
    """
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    n_batches = 0
    is_classification = True
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            n_batches = batch_idx + 1
            mean_loss = eval_loss / n_batches
            # argmax-vs-target accuracy only makes sense for (B,) integer class labels;
            # for (B, act_dim) float targets the comparison would not even broadcast.
            is_classification = targets.dim() == 1 and not torch.is_floating_point(targets)

            if is_classification:
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                log_data = [batch_idx, len(dataloader), mean_loss, 100.*correct/total, correct, total]
                description = (
                    'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                    (batch_idx, len(dataloader), mean_loss, 100.*correct/total, correct, total)
                )
            else:
                log_data = [batch_idx, len(dataloader), mean_loss]
                description = 'Batch Idx: (%d/%d) | Loss: %.3f' % (batch_idx, len(dataloader), mean_loss)

            csv_writer.writerow(log_data)
            pbar.set_description(description)

    if n_batches == 0:
        # Empty dataloader: no score to report, and nothing to divide by.
        return None

    acc = 100.*correct/total if is_classification else -eval_loss / n_batches

    # Save checkpoint.
    if checkpoint and acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        ckpt_dir = '/content/drive/MyDrive/6_8200_project/S4/checkpoint'
        # exist_ok avoids the check-then-create race and also creates missing parents.
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, os.path.join(ckpt_dir, 'walker_ckpt.pth'))
        best_acc = acc

    return acc


In [None]:
if tuple(map(int, torch.__version__.split('.')[:2])) == (1, 11):
    print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
    dropout_fn = nn.Dropout
if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 12):
    dropout_fn = nn.Dropout1d
else:
    dropout_fn = nn.Dropout2d

In [None]:

## Reference min / max episode returns per environment, taken from infos.py in the official d4rl GitHub repo.

REF_MAX_SCORE = dict(
    halfcheetah=12135.0,
    walker2d=4592.3,
    hopper=3234.3,
)

REF_MIN_SCORE = dict(
    halfcheetah=-280.178953,
    walker2d=1.629008,
    hopper=-20.272305,
)


## calculated from d4rl datasets
## Per-dataset observation statistics, keyed by D4RL dataset name ('<env>-<split>-v2').
## Each entry holds per-dimension 'state_mean' and 'state_std' lists (17 values for
## halfcheetah / walker2d, 11 for hopper). The data-loading cell looks up the entry for
## `env_s4_name` and uses it to normalise observations as (obs - mean) / std.
## NOTE(review): the std values presumably already include the +1e-6 added in the
## commented-out recomputation code — confirm before adding another epsilon.

D4RL_DATASET_STATS = {
        'halfcheetah-medium-v2': {
                'state_mean':[-0.06845773756504059, 0.016414547339081764, -0.18354906141757965,
                              -0.2762460708618164, -0.34061527252197266, -0.09339715540409088,
                              -0.21321271359920502, -0.0877423882484436, 5.173007488250732,
                              -0.04275195300579071, -0.036108363419771194, 0.14053793251514435,
                              0.060498327016830444, 0.09550975263118744, 0.06739100068807602,
                              0.005627387668937445, 0.013382787816226482
                ],
                'state_std':[0.07472999393939972, 0.3023499846458435, 0.30207309126853943,
                             0.34417077898979187, 0.17619241774082184, 0.507205605506897,
                             0.2567007839679718, 0.3294812738895416, 1.2574149370193481,
                             0.7600541710853577, 1.9800915718078613, 6.565362453460693,
                             7.466367721557617, 4.472222805023193, 10.566964149475098,
                             5.671932697296143, 7.4982590675354
                ]
            },
        'halfcheetah-medium-replay-v2': {
                'state_mean':[-0.12880703806877136, 0.3738119602203369, -0.14995987713336945,
                              -0.23479078710079193, -0.2841278612613678, -0.13096535205841064,
                              -0.20157982409000397, -0.06517726927995682, 3.4768247604370117,
                              -0.02785065770149231, -0.015035249292850494, 0.07697279006242752,
                              0.01266712136566639, 0.027325302362442017, 0.02316424623131752,
                              0.010438721626996994, -0.015839405357837677
                ],
                'state_std':[0.17019015550613403, 1.284424901008606, 0.33442774415016174,
                             0.3672759234905243, 0.26092398166656494, 0.4784106910228729,
                             0.3181420564651489, 0.33552637696266174, 2.0931615829467773,
                             0.8037433624267578, 1.9044333696365356, 6.573209762573242,
                             7.572863578796387, 5.069749355316162, 9.10555362701416,
                             6.085654258728027, 7.25300407409668
                ]
            },
        'halfcheetah-medium-expert-v2': {
                'state_mean':[-0.05667462572455406, 0.024369969964027405, -0.061670560389757156,
                              -0.22351515293121338, -0.2675151228904724, -0.07545716315507889,
                              -0.05809682980179787, -0.027675075456500053, 8.110626220703125,
                              -0.06136331334710121, -0.17986927926540375, 0.25175222754478455,
                              0.24186332523822784, 0.2519369423389435, 0.5879552960395813,
                              -0.24090635776519775, -0.030184272676706314
                ],
                'state_std':[0.06103534251451492, 0.36054104566574097, 0.45544400811195374,
                             0.38476887345314026, 0.2218363732099533, 0.5667523741722107,
                             0.3196682929992676, 0.2852923572063446, 3.443821907043457,
                             0.6728139519691467, 1.8616976737976074, 9.575807571411133,
                             10.029894828796387, 5.903450012207031, 12.128185272216797,
                             6.4811787605285645, 6.378620147705078
                ]
            },
        'walker2d-medium-v2': {
                'state_mean':[1.218966007232666, 0.14163373410701752, -0.03704913705587387,
                              -0.13814310729503632, 0.5138224363327026, -0.04719110205769539,
                              -0.47288352251052856, 0.042254164814949036, 2.3948874473571777,
                              -0.03143199160695076, 0.04466355964541435, -0.023907244205474854,
                              -0.1013401448726654, 0.09090937674045563, -0.004192637279629707,
                              -0.12120571732521057, -0.5497063994407654
                ],
                'state_std':[0.12311358004808426, 0.3241879940032959, 0.11456084251403809,
                             0.2623065710067749, 0.5640279054641724, 0.2271878570318222,
                             0.3837319612503052, 0.7373676896095276, 1.2387926578521729,
                             0.798020601272583, 1.5664079189300537, 1.8092705011367798,
                             3.025604248046875, 4.062486171722412, 1.4586567878723145,
                             3.7445690631866455, 5.5851287841796875
                ]
            },
        'walker2d-medium-replay-v2': {
                'state_mean':[1.209364652633667, 0.13264022767543793, -0.14371201395988464,
                              -0.2046516090631485, 0.5577612519264221, -0.03231537342071533,
                              -0.2784661054611206, 0.19130706787109375, 1.4701707363128662,
                              -0.12504704296588898, 0.0564953051507473, -0.09991033375263214,
                              -0.340340256690979, 0.03546293452382088, -0.08934258669614792,
                              -0.2992438077926636, -0.5984178185462952
                ],
                'state_std':[0.11929835379123688, 0.3562574088573456, 0.25852200388908386,
                             0.42075422406196594, 0.5202291011810303, 0.15685082972049713,
                             0.36770978569984436, 0.7161387801170349, 1.3763766288757324,
                             0.8632221817970276, 2.6364643573760986, 3.0134117603302,
                             3.720684051513672, 4.867283821105957, 2.6681625843048096,
                             3.845186948776245, 5.4768385887146
                ]
            },
        'walker2d-medium-expert-v2': {
                'state_mean':[1.2294334173202515, 0.16869689524173737, -0.07089081406593323,
                              -0.16197483241558075, 0.37101927399635315, -0.012209027074277401,
                              -0.42461398243904114, 0.18986578285694122, 3.162475109100342,
                              -0.018092676997184753, 0.03496946766972542, -0.013921679928898811,
                              -0.05937029421329498, -0.19549426436424255, -0.0019200450042262673,
                              -0.062483321875333786, -0.27366524934768677
                ],
                'state_std':[0.09932824969291687, 0.25981399416923523, 0.15062759816646576,
                             0.24249176681041718, 0.6758718490600586, 0.1650741547346115,
                             0.38140663504600525, 0.6962361335754395, 1.3501490354537964,
                             0.7641991376876831, 1.534574270248413, 2.1785972118377686,
                             3.276582717895508, 4.766193866729736, 1.1716983318328857,
                             4.039782524108887, 5.891613960266113
                ]
            },
        'hopper-medium-v2': {
                'state_mean':[1.311279058456421, -0.08469521254301071, -0.5382719039916992,
                              -0.07201576232910156, 0.04932365566492081, 2.1066856384277344,
                              -0.15017354488372803, 0.008783451281487942, -0.2848185896873474,
                              -0.18540096282958984, -0.28461286425590515
                ],
                'state_std':[0.17790751159191132, 0.05444620922207832, 0.21297138929367065,
                             0.14530418813228607, 0.6124444007873535, 0.8517446517944336,
                             1.4515252113342285, 0.6751695871353149, 1.5362390279769897,
                             1.616074562072754, 5.607253551483154
                ]
            },
        'hopper-medium-replay-v2': {
                'state_mean':[1.2305138111114502, -0.04371410980820656, -0.44542956352233887,
                              -0.09370097517967224, 0.09094487875699997, 1.3694725036621094,
                              -0.19992674887180328, -0.022861352190375328, -0.5287045240402222,
                              -0.14465883374214172, -0.19652697443962097
                ],
                'state_std':[0.1756512075662613, 0.0636928603053093, 0.3438323438167572,
                             0.19566889107227325, 0.5547984838485718, 1.051029920578003,
                             1.158307671546936, 0.7963128685951233, 1.4802359342575073,
                             1.6540331840515137, 5.108601093292236
                ]
            },
        'hopper-medium-expert-v2': {
                'state_mean':[1.3293815851211548, -0.09836531430482864, -0.5444297790527344,
                              -0.10201650857925415, 0.02277466468513012, 2.3577215671539307,
                              -0.06349576264619827, -0.00374026270583272, -0.1766270101070404,
                              -0.11862941086292267, -0.12097819894552231
                ],
                'state_std':[0.17012375593185425, 0.05159067362546921, 0.18141433596611023,
                             0.16430604457855225, 0.6023368239402771, 0.7737284898757935,
                             1.4986555576324463, 0.7483318448066711, 1.7953159809112549,
                             2.0530025959014893, 5.725032806396484
                ]
            },
    }



In [None]:
dataset = "medium"       # medium / medium-replay / medium-expert
rtg_scale = 1000                # scale to normalize returns to go

env_name = 'Walker2d-v3'
rtg_target = 5000
env_s4_name = f'walker2d-{dataset}-v2'

# Derive the pickle path from `dataset`, so the loaded trajectories always match the
# normalisation statistics looked up under `env_s4_name`.
dataset_path = f'/content/drive/MyDrive/6_8200_project/data/{env_s4_name}.pkl'

log_dir = "./s4_runs/"
os.makedirs(log_dir, exist_ok=True)

# Fall back to CPU when no GPU is attached instead of failing later on .to(device).
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print("device set to: ", device)

In [None]:
from datetime import datetime
import csv



In [None]:
import pickle

start_time = datetime.now().replace(microsecond=0)

start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")

prefix = "s4_" + env_s4_name

save_model_name = prefix + "_model_" + start_time_str + ".pt"
save_model_path = os.path.join(log_dir, save_model_name)
save_best_model_path = save_model_path[:-3] + "_best.pt"

log_csv_name = prefix + "_log_" + start_time_str + ".csv"
log_csv_path = os.path.join(log_dir, log_csv_name)

# Line-buffered (buffering=1) so each row reaches disk immediately. The handle is kept
# in a name so it can be closed explicitly (log_csv_file.close()) once training ends.
log_csv_file = open(log_csv_path, 'a', 1)
csv_writer = csv.writer(log_csv_file)
csv_header = (["duration", "num_updates", "action_loss",
               "eval_avg_reward", "eval_avg_ep_len", "eval_d4rl_score"])

csv_writer.writerow(csv_header)

print("=" * 60)
print("start time: " + start_time_str)
print("=" * 60)

print("device set to: " + str(device))
print("dataset path: " + dataset_path)
print("model save path: " + save_model_path)
print("log csv save path: " + log_csv_path)

# In this cell the environment is only used for the observation / action dimensions.
env = gym.make(env_name)

state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# NOTE: pickle.load can execute arbitrary code; only point dataset_path at a trusted file.
with open(dataset_path, 'rb') as f:
    trajectories = pickle.load(f)

# Precomputed observation statistics for this dataset (see D4RL_DATASET_STATS).
traj_mean = np.asarray(D4RL_DATASET_STATS[env_s4_name]['state_mean'], dtype=np.float32)
traj_std = np.asarray(D4RL_DATASET_STATS[env_s4_name]['state_std'], dtype=np.float32)

# `trajectories` is a list of per-episode dicts, so the statistics are applied to each
# episode's observation array. Subtracting them from the list itself raises TypeError.
# New dicts are built so the raw `trajectories` stay untouched.
# NOTE(review): assumes the D4RL trajectory layout — each dict holds 'observations' of
# shape (T, state_dim) and 'actions' of shape (T, act_dim); confirm for this pickle.
norm_trajectories = [
    {**traj, 'observations': (np.asarray(traj['observations']) - traj_mean) / traj_std}
    for traj in trajectories
]

# Number of consecutive states the model sees per example (the sequence length L).
# NOTE(review): 20 is an assumed starting value, not derived from the data; tune it.
context_len = 20


class StateWindowDataset(torch.utils.data.Dataset):
    """Behaviour-cloning examples: a window of states and the action at its last step.

    The item for (trajectory i, step t) is
    ``(observations[t - context_len + 1 : t + 1], actions[t])`` as float32 tensors of
    shape (context_len, state_dim) and (act_dim,). That is one (L, d_input) sequence
    for ``S4Model`` plus its regression target. Only full windows are produced, so
    trajectories shorter than ``context_len`` contribute no items. Windows are sliced
    lazily in ``__getitem__`` to avoid materialising every overlapping copy.
    """

    def __init__(self, trajs, context_len):
        self.trajs = list(trajs)
        self.context_len = context_len
        self.index = [
            (traj_idx, t)
            for traj_idx, traj in enumerate(self.trajs)
            for t in range(context_len - 1, len(traj['observations']))
        ]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        traj_idx, t = self.index[idx]
        traj = self.trajs[traj_idx]
        states = traj['observations'][t - self.context_len + 1 : t + 1]
        return (
            torch.as_tensor(states, dtype=torch.float32),
            torch.as_tensor(traj['actions'][t], dtype=torch.float32),
        )


# Split whole trajectories (not windows) so no episode lands in both train and val.
train_trajs, val_trajs = split_train_val(norm_trajectories, val_split=0.1)
trainset = StateWindowDataset(train_trajs, context_len)
valset = StateWindowDataset(val_trajs, context_len)

# NOTE(review): batch_size / num_workers are assigned in the hyperparameter cell that comes
# after this one; move that cell above. Until then, fall back to the same values so a fresh
# "Restart & Run All" does not raise NameError here.
batch_size = globals().get('batch_size', 64)
num_workers = globals().get('num_workers', 4)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


In [None]:
# Optimisation
lr, weight_decay, epochs = 0.01, 0.01, 100

# Data loading
num_workers, batch_size = 4, 64
grayscale = True  # not referenced by the cells in this notebook

# Model architecture
n_layers, d_model = 4, 128
dropout, prenorm = 0.1, False

# Restore weights from a saved checkpoint before training?
resume = False

In [None]:
# Kept as a plain string: the model cell compares `device == 'cuda'`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Best validation score seen so far. Start at -inf rather than 0 so the first evaluation
# is always checkpointed, even when its score is 0 or negative (e.g. a negated loss).
best_acc = -float('inf')
start_epoch = 0  # start from epoch 0 or the epoch after the last checkpoint

# Data
print(f'==> Preparing {env_name} data..')

In [None]:
# Model
print('==> Building model..')
model = S4Model(
    d_input=state_dim,
    d_output=act_dim,
    d_model=d_model,
    n_layers=n_layers,
    dropout=dropout,
    prenorm=prenorm,
)

model = model.to(device)
if device == 'cuda':
    cudnn.benchmark = True

# Same file that `eval(..., checkpoint=True)` writes the best model to.
ckpt_path = '/content/drive/MyDrive/6_8200_project/S4/checkpoint/walker_ckpt.pth'

if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isfile(ckpt_path), f'Error: no checkpoint found at {ckpt_path}!'
    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    # The checkpoint is written after epoch `epoch` finished; continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    # NOTE(review): optimizer and LR-scheduler state are not checkpointed, so they restart fresh.

In [None]:
# Targets are continuous action vectors (d_output=act_dim, logged as "action_loss"), so use
# a regression loss. CrossEntropyLoss would treat float targets as class probabilities.
criterion = nn.MSELoss()
optimizer, scheduler = setup_optimizer(
    model, lr=lr, weight_decay=weight_decay, epochs=epochs
)

In [None]:
# No validation score yet. Keying on this instead of `epoch == 0` also covers resuming
# at start_epoch > 0, where `val_acc` would otherwise be read before being assigned.
val_acc = None
pbar = tqdm(range(start_epoch, epochs))
for epoch in pbar:
    if val_acc is None:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    train()
    val_acc = eval(epoch, valloader, checkpoint=True)
    # One cosine-annealing step per epoch.
    scheduler.step()