In [1]:
pip install transformers



In [2]:
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model, GPT2Config

In [3]:
# Run on the first CUDA device when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [15]:
# Problem-size constants shared by the cells below.
A = 5        # number of actions (bandit arms)
N = 80_000   # number of offline samples; passed to generate_B as the number of datasets
n = 500      # rows per dataset; BanditDataset permutes this many rows per item
eps = 1e-10  # NOTE(review): not referenced in the visible cells -- presumably a numerical-stability epsilon

In [5]:
def generate_B(A=5, N=80000, n=500, seed=None):
    """Simulate ``N`` independent ``A``-armed bandit logs with ``n`` pulls each.

    For every log:

    * a behaviour policy ``p`` is built as ``(1 - w) * Dirichlet(1, ..., 1) + w * e_k``
      where ``k`` is a uniformly chosen arm and ``w`` is uniform on
      ``{0.1, 0.2, ..., 1.0}``;
    * arm means ``mu`` are drawn uniformly from ``[0, 1)``;
    * ``n`` actions are sampled from ``p`` and each reward is drawn from a normal
      distribution with mean ``mu[action]`` and standard deviation ``0.3``.

    Args:
        A: number of arms.
        N: number of logs to generate.
        n: number of pulls per log.
        seed: optional seed for a private ``np.random.RandomState``. When ``None``
            (the default) the global ``np.random`` stream is used, as before.

    Returns:
        A 4-tuple of length-``N`` lists ``(dsets, actions, coefficients, opt_idxs)``:

        * ``dsets[i]``: float32 array of shape ``(n, A + 3)`` laid out as
          ``[1, one_hot(action) (A columns), 1, reward]``;
        * ``actions[i]``: shape ``(n,)`` sampled arm indices;
        * ``coefficients[i]``: shape ``(n,)``, ``exp(mu[action] - p @ mu)``;
        * ``opt_idxs[i]``: shape ``(n,)``, every entry equal to ``argmax(mu)``.
    """
    # The ``np.random`` module and ``RandomState`` expose the same sampling methods,
    # so an unseeded call consumes the global stream in the same order as before.
    rng = np.random if seed is None else np.random.RandomState(seed)

    row_idx = np.arange(n)
    dsets, actions, coefficients, opt_idxs = [], [], [], []
    for _ in range(N):
        # Behaviour policy: a random distribution mixed with a point mass on one arm.
        p_1 = rng.dirichlet(np.ones(A))
        p_2 = np.zeros(A)
        p_2[rng.choice(A)] = 1
        w = (rng.choice(10) + 1) / 10
        p = (1 - w) * p_1 + w * p_2

        a = rng.choice(A, n, p=p)
        mu = rng.rand(A)
        r = rng.normal(mu[a], 0.3)

        actions.append(a)
        # Per-pull weight: exponentiated gap between the pulled arm's mean and the
        # policy's expected mean reward.
        coefficients.append(np.exp(mu[a] - np.dot(p, mu)))
        opt_idxs.append(np.full(n, np.argmax(mu)))

        X = np.zeros((n, A + 3), np.float32)
        X[:, 0] = 1
        X[row_idx, 1 + a] = 1  # one-hot action written straight into columns 1..A
        X[:, -2] = 1
        X[:, -1] = r
        dsets.append(X)
    return dsets, actions, coefficients, opt_idxs

In [6]:
class BanditDataset(Dataset):
    """Dataset over pre-generated bandit logs.

    Each item is one log with its rows freshly shuffled on every access and a
    fixed start row (``[1, 0, ..., 0]``) prepended, together with the actions,
    coefficients and optimal-action indices reordered by the same permutation.
    """

    def __init__(self, dsets, actions, coefficients, opt_idxs):
        """Store the parallel lists produced by the data generator.

        Args:
            dsets: sequence of 2-D float arrays, one per log.
            actions: sequence of 1-D arrays aligned row-for-row with ``dsets``.
            coefficients: sequence of 1-D arrays aligned row-for-row with ``dsets``.
            opt_idxs: sequence of 1-D arrays aligned row-for-row with ``dsets``.
        """
        self.dsets = dsets
        self.actions = actions
        self.coefficients = coefficients
        self.opt_idxs = opt_idxs

        # Size the start row from the stored data rather than the notebook-level
        # constant `A`, so logs generated with a different number of actions work.
        # The global is only consulted when there is no data to inspect.
        n_features = np.shape(dsets[0])[1] if len(dsets) > 0 else A + 3
        self.first = np.zeros((1, n_features), dtype=np.float32)
        self.first[0, 0] = 1

    def __len__(self):
        """Number of stored logs."""
        return len(self.dsets)

    def __getitem__(self, idx):
        """Return ``(rows, actions, coefficients, opt_idxs)`` for log ``idx``.

        ``rows`` has one more row than the stored log (the start row comes first);
        the other three arrays keep the stored length.
        """
        sample_ds = self.dsets[idx]
        # Permute every row of this log; the length comes from the log itself
        # instead of the notebook-level constant `n`.
        p = np.random.permutation(len(sample_ds))
        sample_ds = np.concatenate((self.first, sample_ds[p]))
        sample_actions = self.actions[idx][p]
        sample_coefficients = self.coefficients[idx][p]
        sample_opt_idxs = self.opt_idxs[idx][p]

        return sample_ds, sample_actions, sample_coefficients, sample_opt_idxs

In [7]:
class TransformerModel(nn.Module):
    """GPT-2 backbone with linear read-in / read-out layers.

    Input rows of width ``n_states + 3`` are projected to the embedding size,
    passed through a ``GPT2Model`` whose position embeddings are zeroed and
    frozen, and projected to ``n_states`` logits per sequence position.
    """

    def __init__(self, n_states, n_positions=501, n_embd=32, n_layer=4, n_head=4):
        """Build the model.

        Args:
            n_states: sizes both the input rows (``n_states + 3`` features) and the
                number of output logits.
            n_positions: maximum sequence length given to ``GPT2Config``.
            n_embd: embedding width of the backbone.
            n_layer: number of transformer layers.
            n_head: number of attention heads.
        """
        super().__init__()
        configuration = GPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_states
        self._read_in = nn.Linear(n_states + 3, n_embd)
        self._backbone = GPT2Model(configuration)
        # The head size follows `n_states` instead of a hard-coded 5, so it always
        # matches the read-in layer. Identical to the previous behaviour for
        # n_states=5.
        self._read_out = nn.Linear(n_embd, n_states)
        self._flatten = nn.Flatten(0, 1)

        # Zero the learned position embeddings and exclude them from training, so
        # the backbone adds no positional signal to the inputs.
        with torch.no_grad():
            for w in self._backbone.wpe.parameters():
                w.zero_()
        self._backbone.wpe.weight.requires_grad_(False)

    def forward(self, X):
        """Return logits for every position except the last one.

        Args:
            X: tensor whose last dimension is ``n_states + 3``; indexed here as
                ``(batch, sequence, features)``.

        Returns:
            Tensor of shape ``(batch * (sequence - 1), n_states)``: the final
            position is dropped and the batch and sequence axes are merged.
        """
        embeds = self._read_in(X)
        hidden = self._backbone(inputs_embeds=embeds).last_hidden_state
        logit = self._read_out(hidden)[:, :-1]
        return self._flatten(logit)

In [8]:
def loss_fn(pred, a, c):
    """Coefficient-weighted cross-entropy.

    Computes the per-sample cross-entropy between logits ``pred`` and class
    indices ``a``, multiplies each term by the matching entry of ``c`` and
    returns the mean over all samples.
    """
    per_sample = F.cross_entropy(pred, a, reduction='none')
    weighted = per_sample * c
    return weighted.mean()

In [16]:
# Generate the training split (N logs) and a validation split a quarter of that
# size (N // 4 logs). `A` and `n` are left at generate_B's defaults.
dsets_train, actions_train, coefficients_train, opt_idxs_train = generate_B(N=N)
dsets_val, actions_val, coefficients_val, opt_idxs_val = generate_B(N=N//4)

# Wrap each split so that every access returns a freshly permuted log.
data_train = BanditDataset(dsets_train, actions_train, coefficients_train, opt_idxs_train)
data_val = BanditDataset(dsets_val, actions_val, coefficients_val, opt_idxs_val)

# Batches of 76 logs; no `shuffle` argument is passed to either loader.
train_dataloader = DataLoader(data_train, batch_size=76)
val_dataloader = DataLoader(data_val, batch_size=76)

In [10]:
# Build the transformer with n_states=5, move its parameters to `device`, then
# wrap it in nn.DataParallel. After wrapping, the underlying TransformerModel is
# reached through `model.module` (used when saving checkpoints).
model = TransformerModel(n_states=5)
model.to(device)
model = nn.DataParallel(model)

In [11]:
# Hand AdamW only the parameters that are still trainable (default hyper-parameters).
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad])

In [12]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    For every batch ``(X, a, c, opt)`` the model output for ``X`` is scored with
    ``loss_fn(pred, a, c)`` (``a`` and ``c`` flattened), followed by one
    optimizer step. Every tenth batch prints the loss, a running sample count,
    and a diagnostic: the mean squared gap between the softmax probability the
    model gives to ``opt`` and the probability it gives to ``a``.

    Args:
        dataloader: iterable of ``(X, a, c, opt)`` batches exposing ``.dataset``.
        model: module mapping ``X`` to logits with one row per flattened target.
        loss_fn: callable ``(pred, a, c) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
    """
    size = len(dataloader.dataset)
    model.train()
    # Iterate the loader that was passed in; this previously read the
    # notebook-level `train_dataloader` and ignored the argument.
    for batch, (X, a, c, opt) in enumerate(dataloader):
        X = X.to(device)
        a = a.flatten().to(device)
        c = c.flatten().to(device)

        pred = model(X)
        loss = loss_fn(pred, a, c)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 10 == 0:
            # Diagnostics are only needed when they are printed, and never need
            # gradients. `pred` is the pre-step forward output, as before.
            with torch.no_grad():
                opt = opt.flatten().to(device)
                prob = F.softmax(pred, dim=1)
                rows = torch.arange(prob.shape[0], device=prob.device)
                comp = prob[rows, opt] - prob[rows, a]
                mse = torch.mean(torch.pow(comp, 2))

            current = (batch + 1) * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")
            print("Comparison:", mse)

In [13]:
def test(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for X, a, c, opt in dataloader:
            X = X.to(device)
            pred = model(X)
            a = a.flatten().to(device)
            c = c.flatten().to(device)
            opt = opt.flatten().to(device)

            loss = loss_fn(pred, a, c)
            val_loss += loss.item()

    val_loss /= num_batches
    print(f"Val loss: {val_loss:>8f} \n")
    return val_loss

In [None]:
# Train for a fixed number of epochs, evaluating and checkpointing after each one.
epochs = 100
# Print floats inside NumPy arrays with three decimals.
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
# NOTE(review): `cur_val_loss` is initialised but never updated by the active code
# below; it is only referenced by the commented-out "keep the best model" variant.
cur_val_loss = np.inf
cur_epoch = 1
cur_state_dict = None
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    cur_epoch = t + 1
    # `model` is wrapped in nn.DataParallel, so the state dict is taken from the
    # inner module.
    cur_state_dict = model.module.state_dict()
    # The validation loss is printed by test(); its return value is not used here.
    test(val_dataloader, model, loss_fn)
    # The checkpoint file is overwritten every epoch, so it always holds the most
    # recent weights rather than the best ones.
    torch.save({
        'epoch': cur_epoch,
        'model_state_dict': cur_state_dict,
        }, 'transformer_model.pt')
    # Disabled alternative: save only when the validation loss improves.
    # NOTE(review): as written it never assigns `cur_val_loss`, so re-enabling it
    # unchanged would still save on every epoch with a finite validation loss.
    # if test(val_dataloader, model, loss_fn) < cur_val_loss:
    #     cur_epoch = t + 1
    #     cur_state_dict = model.module.state_dict()
    #     torch.save({
    #         'epoch': cur_epoch,
    #         'model_state_dict': cur_state_dict,
    #         }, 'transformer_model.pt')
print("Done!")

Epoch 1
-------------------------------
loss: 0.924297  [   76/80000]
Comparison: tensor(0.2727, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 0.986292  [  836/80000]
Comparison: tensor(0.2665, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 1.016658  [ 1596/80000]
Comparison: tensor(0.2226, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 1.000330  [ 2356/80000]
Comparison: tensor(0.2538, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 0.883904  [ 3116/80000]
Comparison: tensor(0.2538, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 0.994679  [ 3876/80000]
Comparison: tensor(0.2279, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 0.959305  [ 4636/80000]
Comparison: tensor(0.2444, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 0.877912  [ 5396/80000]
Comparison: tensor(0.2803, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 0.905468  [ 6156/80000]
Comparison: tensor(0.2756, device='cuda:0', grad_fn=<MeanBackward0>)
loss: 0.939286  [ 6916/80000]
Comparison: tensor(0.2658, device='cuda