In [1]:
import random
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model, GPT2Config

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
A = 5 # number of actions
N = 80000 # number of offline samples
n = 500
eps = 1e-10

In [4]:
def generate_B(A=5, N=80000, n=500):
    dsets, max_idx = [], []
    for i in range(N):
        p_1 = np.random.dirichlet(np.ones(A), n)
        p_2 = np.zeros((n, A))
        idx = np.random.choice(np.arange(A), n)
        p_2[np.arange(n), idx] = 1
        w = (np.random.choice(10, (n, 1)) + 1) / 10
        p = (1 - w) * p_1 + w * p_2

        cum = p.cumsum(1)
        u = np.random.rand(n, 1)
        a = (u < cum).argmax(1)

        mu = np.random.rand(A)
        max_idx.append(np.argmax(mu))
        r = np.random.normal(mu[a], 0.3)
        
        a_one_hot = np.zeros((n, A))
        a_one_hot[np.arange(n), a] = 1

        X = np.zeros((n, A + 3), np.float32)
        X[:, 0] = 1
        X[:, 1:A + 1] = a_one_hot
        X[:, -2] = 1
        X[:, -1] = r
        dsets.append(X)
    return dsets, max_idx

In [5]:
class BanditDataset(Dataset):
    def __init__(self, dsets, max_idx):
        self.dsets = dsets
        self.max_idx = max_idx
        self.first = np.zeros((1, A + 3), dtype=np.float32)
        self.first[0, 0] = 1
    
    def __len__(self):
        return len(self.dsets)

    def __getitem__(self, idx):
        sample_ds = self.dsets[idx]
        np.random.shuffle(sample_ds)
        sample_ds = torch.from_numpy(np.concatenate((self.first, sample_ds)))
        sample_max_idx = self.max_idx[idx]
        
        return sample_ds, sample_max_idx

In [6]:
class TransformerModel(nn.Module):
    def __init__(self, n_states, n_positions=501, n_embd=32, n_layer=4, n_head=4):
        super(TransformerModel, self).__init__()
        configuration = GPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_states
        self._read_in = nn.Linear(n_states + 3, n_embd)
        self._backbone = GPT2Model(configuration)
        self._read_out = nn.Linear(n_embd, 5)
        self._softmax = nn.Softmax(dim=2)

        for w in self._backbone.wpe.parameters():
            w.data.fill_(0)
        self._backbone.wpe.weight.requires_grad=False

    def forward(self, X):
        embeds = self._read_in(X)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        logit = self._read_out(output)
        prediction = self._softmax(logit)

        return prediction

In [7]:
def loss_fn(preds, max_idx):
    p = torch.gather(preds, 2, max_idx)
    return torch.mean(-torch.log(p + eps)[:, 1:])

In [8]:
def train(dataloader, model, optimizer):
    model.train()
    for batch, (X, max_idx) in enumerate(dataloader):
        X = X.to(device)
        prediction = model.forward(X)
           
        max_idx = torch.zeros((X.shape[0], X.shape[1], 1), dtype=torch.int64) + max_idx[:, None, None]
        max_idx = max_idx.to(device)

        loss = loss_fn(prediction, max_idx)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        loss = loss.item()
        if batch % 100 == 0:
            print(f"Train loss: {loss:>7f}")

In [9]:
def test(dataloader, model):
    num_batches = len(dataloader)
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X, max_idx in dataloader:
            X = X.to(device)
            prediction = model.forward(X)
        
            max_idx = torch.zeros((X.shape[0], X.shape[1], 1), dtype=torch.int64) + max_idx[:, None, None]
            max_idx = max_idx.to(device)
            
            loss = loss_fn(prediction, max_idx)
            
            val_loss += loss.item()
    val_loss /= num_batches
    print(f"Val loss: {val_loss:>8f} \n")
    return val_loss

In [10]:
model = TransformerModel(n_states=5)
model.to(device)
model = nn.DataParallel(model)

In [11]:
optimizer = torch.optim.AdamW([param for param in model.parameters() if param.requires_grad == True], 1e-4, weight_decay=1e-4)

In [12]:
dsets_train, max_idx_train = generate_B(N=N)
dsets_val, max_idx_val = generate_B(N=N // 4)

train_data = BanditDataset(dsets_train, max_idx_train)
val_data = BanditDataset(dsets_val, max_idx_val)

train_dataloader = DataLoader(train_data, batch_size=64)
val_dataloader = DataLoader(val_data, batch_size=64)

In [13]:
model.train()
for batch, (X, max_idx) in enumerate(train_dataloader):
    X = X.to(device)
    prediction = model.forward(X)
    print(prediction.grad_fn)

    max_idx = torch.zeros((X.shape[0], X.shape[1], 1), dtype=torch.int64) + max_idx[:, None, None]
    max_idx = max_idx.to(device)
    
    loss = loss_fn(prediction, max_idx)
    print(loss.grad_fn)
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    loss = loss.item()
    if batch % 100 == 0:
        print(f"Train loss: {loss:>7f}")

<SoftmaxBackward0 object at 0x2ab367e77b50>
<MeanBackward0 object at 0x2ab367e77b80>
Train loss: 1.710211
<SoftmaxBackward0 object at 0x2ab367e77b80>
<MeanBackward0 object at 0x2ab367e77b80>
<SoftmaxBackward0 object at 0x2ab367e76590>
<MeanBackward0 object at 0x2ab367e77d60>
<SoftmaxBackward0 object at 0x2ab367e77d60>
<MeanBackward0 object at 0x2ab367e77d60>
<SoftmaxBackward0 object at 0x2ab367e77d60>
<MeanBackward0 object at 0x2ab367e77d60>
<SoftmaxBackward0 object at 0x2ab367e76530>
<MeanBackward0 object at 0x2ab367e77ee0>
<SoftmaxBackward0 object at 0x2ab367e77ee0>
<MeanBackward0 object at 0x2ab367e77ee0>
<SoftmaxBackward0 object at 0x2ab367e77f40>
<MeanBackward0 object at 0x2ab367e77f40>
<SoftmaxBackward0 object at 0x2ab367e77d60>
<MeanBackward0 object at 0x2ab367e75d50>
<SoftmaxBackward0 object at 0x2ab367e75d50>
<MeanBackward0 object at 0x2ab367e75d50>
<SoftmaxBackward0 object at 0x2ab367e76590>
<MeanBackward0 object at 0x2ab367e76590>
<SoftmaxBackward0 object at 0x2ab367e77f40>


In [None]:
epochs = 300
cur_val_loss = np.inf
cur_epoch = 1
cur_state_dict = None
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, optimizer)
    if test(val_dataloader, model) < cur_val_loss:
        cur_epoch = t + 1
        cur_state_dict = model.module.state_dict()
print("Done!")

Epoch 1
-------------------------------
Train loss: 1.707279
Train loss: 1.611381
Train loss: 1.595779
Train loss: 1.559452
Train loss: 0.979866
Train loss: 0.724094
Train loss: 0.694121
Train loss: 0.592415
Train loss: 0.618607
Train loss: 0.549395
Train loss: 0.575568
Train loss: 0.649643
Train loss: 0.601816
Val loss: 0.525050 

Epoch 2
-------------------------------
Train loss: 0.500942
Train loss: 0.493246
Train loss: 0.463754
Train loss: 0.574874
Train loss: 0.446787
Train loss: 0.559339
Train loss: 0.522662
Train loss: 0.506971
Train loss: 0.433394
Train loss: 0.484934
Train loss: 0.443036
Train loss: 0.493373
Train loss: 0.476586
Val loss: 0.425715 

Epoch 3
-------------------------------
Train loss: 0.364750
Train loss: 0.398402
Train loss: 0.415509
Train loss: 0.452759
Train loss: 0.335567
Train loss: 0.403636
Train loss: 0.446505
Train loss: 0.396700
Train loss: 0.440471
Train loss: 0.395744
Train loss: 0.380510
Train loss: 0.487094
Train loss: 0.440457
Val loss: 0.388643 

In [None]:
model

In [None]:
torch.save({
            'epoch': cur_epoch,
            'model_state_dict': cur_state_dict,
            }, 'transformer_model_optimal.pt')