In [1]:
!pip install transformers



In [2]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model, GPT2Config
from tqdm import tqdm
import matplotlib.pyplot as plt

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
A = 5 # number of actions
N = 40000 # number of offline datasets
n = 500 # context length
trials = 500 # online regret trials for val/test

In [5]:
def generate_B(A=5, N=40000, n=500):
    dsets, actions = [], []
    for i in tqdm(range(N)):
        mu = np.random.rand(A)

        p_1 = np.random.dirichlet(np.ones(A))
        p_2 = np.zeros(A)
        p_2[np.random.choice(A)] = 1
        w = (np.random.choice(11)) / 10
        p = (1 - w) * p_1 + w * p_2

        a = np.random.choice(A, n, p=p)
        actions.append(a)

        r = np.random.normal(mu[a], 0.3)

        X = np.zeros((n, A + 3), np.float32)
        X[:, [0, -2]] = 1
        X[np.arange(n), a + 1] = 1
        X[:, -1] = r
        dsets.append(X)
    return dsets, actions

In [6]:
class BanditDataset(Dataset):
    def __init__(self, dsets, actions):
        self.dsets = dsets
        self.actions = actions

        self.first = np.zeros((1, A + 3), dtype=np.float32)
        self.first[0, 0] = 1

    def __len__(self):
        return len(self.dsets)

    def __getitem__(self, idx):
        perm = np.random.permutation(n) # shuffled in-context datset to reduce overfitting
        sample_ds = np.concatenate((self.first, self.dsets[idx][perm]))
        return sample_ds, self.actions[idx][perm]

In [7]:
class TransformerModel(nn.Module):
    def __init__(self, n_states, n_positions=501, n_embd=32, n_layer=4, n_head=4):
        super(TransformerModel, self).__init__()
        configuration = GPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_states
        self._read_in = nn.Linear(n_states + 3, n_embd)
        self._backbone = GPT2Model(configuration)
        self._read_out = nn.Linear(n_embd, n_states)
        self._flatten = nn.Flatten(0, 1)

        for w in self._backbone.wpe.parameters(): # remove positional embedding
            w.data.fill_(0)
        self._backbone.wpe.weight.requires_grad=False

    def forward(self, X):
        embeds = self._read_in(X)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        logit = self._read_out(output)[:, :-1]
        logit = self._flatten(logit)
        return logit

In [None]:
dsets_train, actions_train = generate_B(N=N)
dsets_val, actions_val = generate_B(N=N//4)

data_train = BanditDataset(dsets_train, actions_train)
data_val = BanditDataset(dsets_val, actions_val)

train_dataloader = DataLoader(data_train, batch_size=64)
val_dataloader = DataLoader(data_val, batch_size=64)

In [None]:
torch.save(dsets_train, 'dsets_train.pth')
torch.save(dsets_val, 'dsets_val.pth')

In [None]:
torch.save(actions_train, 'actions_train.pth')
torch.save(actions_val, 'actions_val.pth')

In [8]:
dsets_train = torch.load('/kaggle/input/bandit-data1/dsets_train.pth')
dsets_val = torch.load('/kaggle/input/bandit-data1/dsets_val.pth')

In [9]:
actions_train = torch.load('/kaggle/input/bandit-data1/actions_train.pth')
actions_val = torch.load('/kaggle/input/bandit-data1/actions_val.pth')

In [10]:
data_train = BanditDataset(dsets_train, actions_train)
data_val = BanditDataset(dsets_val, actions_val)

train_dataloader = DataLoader(data_train, batch_size=64)
val_dataloader = DataLoader(data_val, batch_size=64)

In [11]:
model = TransformerModel(n_states=A)
model.to(device)
model = nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#optimizer = torch.optim.AdamW([param for param in model.parameters() if param.requires_grad == True])

Train Initial Policy

In [None]:
def train_initial_policy(model, data_loader, optimizer):
    size = len(data_loader.dataset)
    model.train()
    for batch, (X, a) in enumerate(train_dataloader):
        X = X.to(device)
        pred = model(X)
        a = a.flatten().to(device)

        loss_fn = torch.nn.CrossEntropyLoss()
        loss = torch.mean(loss_fn(pred, a))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [None]:
def test_initial_policy(model, dataloader):
    num_batches = len(dataloader)
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for X, a in dataloader:
            X = X.to(device)
            pred = model(X)
            a = a.flatten().to(device)

            loss_fn = torch.nn.CrossEntropyLoss()
            loss = torch.mean(loss_fn(pred, a))
            val_loss += loss.item()

    val_loss /= num_batches
    print(f"Val loss: {val_loss:>8f} \n")
    return val_loss

In [None]:
epochs = 10
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
cur_val_loss = np.inf
cur_epoch = 1
cur_state_dict = None
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_initial_policy(model, train_dataloader, optimizer)
    if test_initial_policy(model, val_dataloader) < cur_val_loss:
        cur_epoch = t + 1
        cur_state_dict = model.module.state_dict()
        torch.save({
            'epoch': cur_epoch,
            'model_state_dict': cur_state_dict,
            }, 'transformer_model.pt')
print("Done!")

In [None]:
torch.save(model.state_dict(), 'initial_policy.pth')

In [12]:
model.load_state_dict(torch.load('/kaggle/input/bandit-data1/initial_policy.pth'))
model.to(device)

DataParallel(
  (module): TransformerModel(
    (_read_in): Linear(in_features=8, out_features=32, bias=True)
    (_backbone): GPT2Model(
      (wte): Embedding(50257, 32)
      (wpe): Embedding(501, 32)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-3): 4 x GPT2Block(
          (ln_1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
    (_read_out): Linear(in_features=32, out_features

In [13]:
def generate(A=5, N=1, n=500):
    mu = np.random.rand(A)

    p_1 = np.random.dirichlet(np.ones(A))
    p_2 = np.zeros(A)
    p_2[np.random.choice(A)] = 1
    w = (np.random.choice(11)) / 10
    p = (1 - w) * p_1 + w * p_2

    a = np.random.choice(A, n, p=p)

    r = np.random.normal(mu[a], 0.3)

    X = np.zeros((n, A + 3), np.float32)
    X[:, [0, -2]] = 1
    X[np.arange(n), a + 1] = 1
    X[:, -1] = r
    return X, mu

In [None]:
reg = np.empty((500, n - 1))

for trial in range(500):
  X, mu = generate()
  X = torch.unsqueeze(torch.from_numpy(X).to(device), 0)
  with torch.no_grad():
    prediction = model.forward(X).cpu().numpy()
  reg[trial] = np.max(mu) - mu[prediction.argmax(1)]

In [None]:
X, mu = generate()
X = torch.unsqueeze(torch.from_numpy(X).to(device), 0)
with torch.no_grad():
  prediction = model.forward(X).cpu().numpy()

Q function: Q(a) = E(Rt|At=a)

In [None]:
def select_action_with_policy(model, X):
    X = torch.unsqueeze(torch.from_numpy(X).to(device), 0)
    with torch.no_grad():
        predictions = model.forward(X).cpu().numpy()
    return predictions

In [66]:
def Q_function(dsets, actions, N=N):
    Q_est = np.zeros((N, A))

    for trial in tqdm(range(N)):
        Qa = np.zeros(A)
        Na = np.zeros(A)

        X = dsets[trial]
        a = actions[trial]

        for i in range(n):
            reward = X[i, -1]
            action = a[i]

            Na[action] += 1
            Qa[action] += reward
        
        Q_est[trial] = np.nan_to_num(Qa / Na)

    return Q_est

In [67]:
# Q_est_train = e_greedy_k_armed_bandit_with_policy(model, dsets_train)
# Q_est_val = e_greedy_k_armed_bandit_with_policy(model, dsets_val, N=N//4)

Q_est_train = Q_function(dsets_train, actions_train)
Q_est_val = Q_function(dsets_val, actions_val, N=N//4)

  Q_est[trial] = np.nan_to_num(Qa / Na)
100%|██████████| 40000/40000 [00:43<00:00, 917.44it/s]
100%|██████████| 10000/10000 [00:10<00:00, 927.75it/s]


In [69]:
torch.save(Q_est_train, 'Q_est_train.pth')
torch.save(Q_est_val, 'Q_est_val.pth')

In [14]:
Q_est_train = torch.load('/kaggle/input/bandit-data1/Q_est_train-2.pth')
Q_est_val = torch.load('/kaggle/input/bandit-data1/Q_est_val-2.pth')

Policy Improvement

In [15]:
init_model = TransformerModel(n_states=A)
init_model.to(device)
init_model = nn.DataParallel(init_model)
init_model.load_state_dict(torch.load('/kaggle/input/bandit-data1/initial_policy.pth'))

<All keys matched successfully>

In [16]:
initial_train_probs = []
with torch.no_grad():
    for X, _ in train_dataloader:
        X = X.to(device)
        probs_log = init_model(X)
        probs = torch.softmax(probs_log, dim=-1)
        initial_train_probs.append(probs.cpu())

In [17]:
initial_val_probs = []
with torch.no_grad():
    for X, _ in val_dataloader:
        X = X.to(device)
        probs_log = init_model(X)
        probs = torch.softmax(probs_log, dim=-1)
        initial_val_probs.append(probs.cpu())

In [32]:
initial_train_probs[0].argmax(1)

tensor([4, 0, 4,  ..., 1, 1, 1])

In [21]:
X, actions = next(iter(train_dataloader))

In [None]:
actions[0] # 64x500

In [22]:
new_probs_log = model(X)
new_probs = torch.softmax(new_probs_log, dim=-1)

In [28]:
new_probs.argmax(1).shape

torch.Size([32000])

In [None]:
# one batch size has 64 trajectories, each trajectory has 500 steps = 64*500 = 30000 , 30000x5
# 30000x1 
# ratios = 30000x1

In [36]:
def ppo_train(model, Q_est, optimizer, old_probs_train, clip_range=0.2):
    model.train()
    size = len(train_dataloader.dataset)
    #old_probs = initial_train_probs
    
    for i, (X, actions) in enumerate(train_dataloader):
        # Forward pass
        #old_probs_log = model(X)
        #old_probs = torch.softmax(old_probs_log, dim=-1)
        old_probs = old_probs_train[i].argmax(1).to(device)
        
        Q_values = np.array([])
        for j in range(64):
            idx = i * 64 + j
            Q_values = np.append(Q_values, Q_est[idx][actions[j]])
        Q_values = torch.tensor(Q_values.reshape(-1, 1))
        
        advantages = (Q_values - Q_values.mean()) / (Q_values.std() + 1e-10)
        advantages = torch.tensor(advantages, dtype=torch.float32).to(device)

        #optimizer.zero_grad()
        new_probs_log = model(X)
        new_probs = torch.softmax(new_probs_log, dim=-1).argmax(1)

        ratios = new_probs / (old_probs + 1e-10)

        surr1 = advantages * ratios
        surr2 = advantages * torch.clamp(ratios, 1-clip_range, 1+clip_range)
        surr_loss = -torch.min(surr1, surr2).mean()
        
        kl_div = torch.distributions.kl_divergence(
            torch.distributions.Categorical(probs=old_probs),
            torch.distributions.Categorical(probs=new_probs)
        ).mean()

        loss = surr_loss + 0.01 * kl_div

        # Take gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        #with torch.no_grad():
        #    updated_probs_log = model(X)
        #    updated_probs = torch.softmax(updated_probs_log, dim=-1)
        old_probs_train[i] = new_probs
        if i % 100 == 0:
            loss, current = loss.item(), (i + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    
    return old_probs_train

In [38]:
def ppo_test(model, Q_est, old_probs_val, clip_range=0.2, kl_coeff=0.01):
    model.eval()
    num_batches = len(val_dataloader)
    val_loss = 0.0

    with torch.no_grad():
        for i, (X, actions) in enumerate(val_dataloader):
            #old_probs_log = model(X)
            #old_probs = torch.softmax(old_probs_log, dim=-1)
            old_probs = old_probs_val[i].argmax(1).to(device)
            
            Q_values = np.array([])
            if i != 10000//64:
                for j in range(64):
                    idx = i * 64 + j
                    if idx >= 10000:
                        break
                    Q_values = np.append(Q_values, Q_est[idx][actions[j]])
            else: # last batch
                for j in range(16):
                    idx = i * 64 + j
                    if idx >= 10000:
                        break
                    Q_values = np.append(Q_values, Q_est[idx][actions[j]])
            Q_values = torch.tensor(Q_values.reshape(-1, 1))
            
            advantages = (Q_values - Q_values.mean()) / (Q_values.std() + 1e-10)
            advantages = torch.tensor(advantages, dtype=torch.float32).to(device)
            
            new_probs_log = model(X)
            new_probs = torch.softmax(new_probs_log, dim=-1).argmax(1)
            
            ratios = new_probs / (old_probs + 1e-10)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-clip_range, 1+clip_range) * advantages
            surr_loss = -torch.min(surr1, surr2).mean()
            
            val_loss += surr_loss.item()
            
            old_probs_val[i] = new_probs
            
    val_loss /= num_batches
    
    print(f"Val loss: {val_loss:>8f} \n")
    return val_loss, old_probs_val

In [None]:
cur_epoch

In [39]:
epochs = 10
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
cur_val_loss = np.inf
cur_epoch = 1
cur_state_dict = None
old_probs1 = initial_train_probs
old_probs2 = initial_val_probs

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    old_probs1 = [prob.detach() for prob in old_probs1]
    old_probs1 = ppo_train(model, Q_est_train, optimizer, old_probs1)
    val_loss, old_probs2 = ppo_test(model, Q_est_val, old_probs2)
    if val_loss < cur_val_loss:
        cur_val_loss = val_loss
        cur_epoch = t + 1
        cur_state_dict = model.module.state_dict()
        torch.save({
            'epoch': cur_epoch,
            'model_state_dict': cur_state_dict,
            }, 'transformer_model.pt')
print("Done!")

Epoch 1
-------------------------------


  advantages = torch.tensor(advantages, dtype=torch.float32).to(device)


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.82 GiB (GPU 0; 14.76 GiB total capacity; 10.26 GiB already allocated; 3.69 GiB free; 10.33 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
torch.save({
            'model_state_dict': model.module.state_dict(),
            }, 'transformer_model.pt')

In [None]:
reg = np.empty((500, n - 1))

for trial in range(500):
    X, mu = generate()
    X = torch.unsqueeze(torch.from_numpy(X).to(device), 0)
    with torch.no_grad():
        prediction = model.forward(X).cpu().numpy()
    reg[trial] = np.max(mu) - mu[prediction.argmax(1)]

In [None]:
reg_init = np.empty((500, n - 1))

for trial in range(500):
    X, mu = generate()
    X = torch.unsqueeze(torch.from_numpy(X).to(device), 0)
    with torch.no_grad():
        prediction = init_model.forward(X).cpu().numpy()
    reg_init[trial] = np.max(mu) - mu[prediction.argmax(1)]

In [None]:
plt.plot(np.arange(n - 1), reg_init.mean(0));

In [None]:
reg

In [None]:
plt.plot(np.arange(n - 1), reg.mean(0));