In [1]:
import random
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model, GPT2Config

In [2]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Experiment-wide settings.
A = 5         # number of actions (bandit arms)
N = 80_000    # number of offline samples (datasets) to generate
n = 500       # same value as generate_B's default `n` (draws per dataset)
eps = 1e-10   # small constant; not referenced by the cells visible here

In [4]:
def generate_B(A=5, N=80000, n=500):
    """Sample ``N`` synthetic offline bandit datasets.

    For every dataset a sampling distribution over the ``A`` actions is built
    by mixing a Dirichlet(1, ..., 1) draw with a point mass on one uniformly
    chosen action; the mixing weight is drawn uniformly from {0.1, ..., 1.0}.
    ``n`` actions are sampled from that distribution, and each reward is drawn
    from a normal with standard deviation 0.3 around the chosen action's mean
    (means are uniform on [0, 1)).

    Each dataset is an ``(n, A + 3)`` float32 array with this column layout:
    column 0 is constant 1, columns 1..A hold the one-hot action, column
    ``A + 1`` is constant 1, and the last column holds the reward.

    Args:
        A: number of actions.
        N: number of datasets to generate.
        n: number of (action, reward) rows per dataset.

    Returns:
        Four parallel lists of length ``N``: the dataset arrays, the index of
        the action with the largest mean, the mean vectors, and the action
        sampling distributions.
    """
    contexts, best_arms, arm_means, policies = [], [], [], []
    row_index = np.arange(n)
    for _ in range(N):
        # Sampling distribution: Dirichlet draw blended with a one-hot spike.
        smooth_part = np.random.dirichlet(np.ones(A))
        favoured = np.random.choice(A)
        spike_part = np.eye(A)[favoured]
        weight = (np.random.choice(10) + 1) / 10
        policy = (1 - weight) * smooth_part + weight * spike_part
        policies.append(policy)

        # Actions taken, the arm means, and noisy rewards for those actions.
        actions = np.random.choice(A, n, p=policy)
        mu = np.random.rand(A)
        arm_means.append(mu)
        best_arms.append(np.argmax(mu))
        rewards = np.random.normal(mu[actions], 0.3)

        # Assemble the rows: [1, one-hot(action), 1, reward].
        rows = np.zeros((n, A + 3), np.float32)
        rows[:, 0] = 1
        rows[row_index, 1 + actions] = 1
        rows[:, -2] = 1
        rows[:, -1] = rewards
        contexts.append(rows)
    return contexts, best_arms, arm_means, policies

In [5]:
class BanditDataset(Dataset):
    """Torch dataset over pre-generated bandit datasets.

    Each item is one dataset with its rows in a fresh random order and a
    start row (1 in column 0, zeros elsewhere) prepended, together with the
    matching best-action index, mean vector and sampling distribution.

    Args:
        dsets: sequence of 2-D float arrays, one per dataset.
        max_idx: per-dataset index of the best action.
        means: per-dataset vector of action means.
        ps: per-dataset action sampling distribution.
    """

    def __init__(self, dsets, max_idx, means, ps):
        self.dsets = dsets
        self.max_idx = max_idx
        self.means = means
        self.ps = ps

        # Take the row width from the data itself instead of a notebook
        # global, so datasets generated with any number of actions work.
        # (An empty dataset has no items to index, so the width is moot.)
        n_features = np.asarray(dsets[0]).shape[1] if len(dsets) > 0 else 1
        self.first = np.zeros((1, n_features), dtype=np.float32)
        self.first[0, 0] = 1

    def __len__(self):
        """Return the number of stored datasets."""
        return len(self.dsets)

    def __getitem__(self, idx):
        """Return ``(rows, best_action, means, sampling_distribution)`` for ``idx``.

        ``rows`` is a tensor holding the start row followed by the dataset's
        rows in random order.
        """
        stored = np.asarray(self.dsets[idx])
        # Index with a random permutation (which copies) rather than
        # shuffling in place, so reading an item never mutates stored data.
        order = np.random.permutation(len(stored))
        sample_ds = torch.from_numpy(np.concatenate((self.first, stored[order])))
        sample_max_idx = self.max_idx[idx]
        sample_means = self.means[idx]
        sample_ps = self.ps[idx]

        return sample_ds, sample_max_idx, sample_means, sample_ps

In [6]:
class TransformerModel(nn.Module):
    """GPT-2 backbone with linear read-in / read-out layers.

    Input rows of width ``n_states + 3`` are projected to the embedding size,
    passed through a GPT-2 stack, and projected to ``n_states`` logits per
    position.

    Args:
        n_states: number of output classes; input rows have ``n_states + 3``
            features.
        n_positions: maximum sequence length given to the GPT-2 config.
        n_embd: embedding width.
        n_layer: number of transformer layers.
        n_head: number of attention heads.
    """

    def __init__(self, n_states, n_positions=501, n_embd=32, n_layer=4, n_head=4):
        super().__init__()
        configuration = GPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_states
        self._read_in = nn.Linear(n_states + 3, n_embd)
        self._backbone = GPT2Model(configuration)
        # One logit per action: follow n_states rather than a literal 5 so the
        # head stays consistent with the read-in layer for other action counts.
        self._read_out = nn.Linear(n_embd, n_states)
        self._flatten = nn.Flatten(0, 1)

        # Zero the learned position embeddings and freeze them so they stay
        # zero during training (presumably to drop positional information
        # from the inputs — confirm intent).
        for w in self._backbone.wpe.parameters():
            nn.init.zeros_(w)
        self._backbone.wpe.weight.requires_grad = False

    def forward(self, X):
        """Return per-position logits for a batch of sequences.

        The output at the first sequence position is discarded, and the batch
        and sequence dimensions of the remaining logits are merged into one.
        """
        embeds = self._read_in(X)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        logit = self._read_out(output)[:, 1:]
        logit = self._flatten(logit)

        return logit

In [7]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: yields ``(X, max_idx, mu, p)`` batches; only ``X`` and
            ``max_idx`` are used here.
        model: maps ``X`` to a 2-D tensor of logits holding the same number of
            prediction rows for every sample in the batch.
        loss_fn: called as ``loss_fn(pred, target)``.
        optimizer: optimizer over the model's trainable parameters.

    Prints the loss of every 100th batch.
    """
    size = len(dataloader.dataset)
    # Move batches to wherever the model's parameters live instead of relying
    # on a notebook-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    # Iterate the loader that was passed in (not a global one).
    for batch, (X, max_idx, mu, p) in enumerate(dataloader):
        X = X.to(model_device)
        pred = model(X)

        # Each sample contributes several prediction rows; repeat its label
        # once per row. The count is derived from the output, not hard-coded.
        rows_per_sample = pred.shape[0] // max_idx.shape[0]
        target = max_idx.repeat_interleave(rows_per_sample).to(model_device)
        loss = loss_fn(pred, target)

        # Clear any stale gradients before accumulating this batch's.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

In [8]:
def test(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    model.eval()
    val_loss = 0
    ds = []
    with torch.no_grad():
        for X, max_idx, mu, p in dataloader:
            X = X.to(device)
            pred = model(X)
        
            max_idx = max_idx.repeat_interleave(500).to(device)
            loss = loss_fn(pred, max_idx)

            mu = mu.to(device)
            d = torch.dot(torch.softmax(pred[-1], 0), mu[-1].float())
            if d < torch.max(mu[-1]) - 0.1 and torch.max(p[-1]) < 1:
                print(mu[-1].cpu().detach().numpy(), torch.softmax(pred[-1], 0).cpu().detach().numpy(), p[-1].cpu().detach().numpy())
            ds.append(torch.max(mu[-1]).item() - d.item())
            val_loss += loss.item()
    val_loss /= num_batches
    print(np.array(ds).mean())
    print(f"Val loss: {val_loss:>8f} \n")
    return val_loss

In [9]:
# Build the network, move it to the selected device, then wrap it in
# DataParallel (the underlying network stays reachable as `model.module`).
model = nn.DataParallel(TransformerModel(n_states=5).to(device))

In [10]:
# Cross-entropy objective over the per-position logits.
loss_fn = nn.CrossEntropyLoss()
# Hand AdamW only the parameters that are still trainable; frozen ones
# (requires_grad is False) are filtered out.
optimizer = torch.optim.AdamW(
    filter(lambda prm: prm.requires_grad, model.parameters())
)

In [11]:
# Generate the synthetic bandit datasets: N for training and a quarter of
# that for validation. A and n are left at generate_B's defaults.
dsets_train, max_idx_train, means_train, ps_train = generate_B(N=N)
dsets_val, max_idx_val, means_val, ps_val = generate_B(N=N//4)

# Wrap the generated lists so they can be served item by item.
train_data = BanditDataset(dsets_train, max_idx_train, means_train, ps_train)
val_data = BanditDataset(dsets_val, max_idx_val, means_val, ps_val)

# Batches of 64 datasets; no shuffle argument is passed, so both loaders
# visit the datasets in their stored order every epoch.
train_dataloader = DataLoader(train_data, batch_size=64)
val_dataloader = DataLoader(val_data, batch_size=64)

In [None]:
epochs = 300
# Print numpy floats with three decimals in the evaluation diagnostics.
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

# Track the best validation loss seen so far and the weights that produced it.
cur_val_loss = np.inf
cur_epoch = 1
cur_state_dict = None
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    val_loss = test(val_dataloader, model, loss_fn)
    if val_loss < cur_val_loss:
        # Record the new best loss; without this every epoch would compare
        # against infinity and the "best" model would always be the last one.
        cur_val_loss = val_loss
        cur_epoch = t + 1
        # state_dict() hands back tensors that share storage with the live
        # parameters, so clone them; otherwise later optimizer steps would
        # silently overwrite the snapshot.
        cur_state_dict = {
            key: value.detach().clone()
            for key, value in model.module.state_dict().items()
        }
print("Done!")

Epoch 1
-------------------------------
loss: 1.736755  [   64/80000]
loss: 1.611459  [ 6464/80000]
loss: 1.049074  [12864/80000]
loss: 1.057561  [19264/80000]
loss: 0.841092  [25664/80000]
loss: 0.845541  [32064/80000]
loss: 1.081182  [38464/80000]
loss: 0.785529  [44864/80000]
loss: 0.854464  [51264/80000]
loss: 0.775369  [57664/80000]
loss: 0.768404  [64064/80000]
loss: 0.838525  [70464/80000]
loss: 0.957628  [76864/80000]
[0.540 0.652 0.350 0.089 0.680] [0.162 0.093 0.247 0.009 0.490] [0.903 0.056 0.001 0.032 0.009]
[0.001 0.033 0.567 0.453 0.113] [0.008 0.004 0.206 0.076 0.707] [0.026 0.030 0.027 0.916 0.001]
[0.909 0.541 0.670 0.745 0.874] [0.358 0.009 0.053 0.491 0.089] [0.796 0.078 0.002 0.013 0.111]
[0.687 0.841 0.667 0.277 0.485] [0.083 0.349 0.469 0.001 0.098] [0.654 0.172 0.075 0.039 0.059]
[0.613 0.941 0.801 0.746 0.456] [0.257 0.608 0.096 0.037 0.003] [0.033 0.007 0.196 0.023 0.742]
[0.932 0.728 0.064 0.537 0.281] [0.798 0.036 0.001 0.028 0.138] [0.141 0.057 0.363 0.433 0

In [None]:
# Write the epoch number and state dict recorded by the training loop to disk.
checkpoint = {
    'epoch': cur_epoch,
    'model_state_dict': cur_state_dict,
}
torch.save(checkpoint, 'transformer_model_optimal.pt')