# Batch Grouping using asym

Make sure you have a GPU runtime

In [1]:
! nvidia-smi

Wed Mar  9 10:56:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Installing asym

In [None]:
# Install asym from source by cloning the repository and lifting its inner
# `asym/` directory to the working directory.
# NOTE: `git clone` fails if ./asym already exists, so re-running this cell
# on a runtime where asym is already set up prints a "fatal" message.
# Step 1: clone the repository into ./asym.
! git clone https://github.com/daeseoklee/asym
# Step 2: rename the whole checkout to ./tmp to free the name `asym`.
! mv asym tmp 
# Step 3: move the `asym` directory found inside the checkout to ./asym
# (presumably the importable package, so that `import asym` works from the
# notebook's working directory -- TODO confirm).
! mv tmp/asym asym

fatal: destination path 'asym' already exists and is not an empty directory.


In [None]:
from numpy.random import choice
import torch 
import torch.nn as nn
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from asym.annotated_module import AnnotatedModule
from asym.data_collection import DataCollection
from asym.grouper import LengthThresholdGrouper, UniGrouper
from asym.precompute_grouper_thresholds import find_thresholds
from tqdm import tqdm

## Creating a Dataset

Let's create a peculiar virtual dataset, where 
* Datapoints have input length in {1, 4, 16, 64}
* Each length appears with a frequency that is inversely proportional to its value
* Each input is associated with a binary label

In [None]:
# Number of synthetic datapoints to generate.
dataset_size = 1024

# Lengths 1, 4, 16 and 64 receive the unnormalised weights 1, 1/4, 1/16 and
# 1/64, i.e. weights inversely proportional to the length itself.
length_distribution = {4 ** i: 4 ** (-i) for i in range(4)}
total_weight = sum(length_distribution.values())
# Normalise the weights so that they form a probability distribution.
length_distribution = {
  length: weight / total_weight
  for length, weight in length_distribution.items()
}

# Draw one length per datapoint according to that distribution.
candidate_lengths = list(length_distribution.keys())
length_probabilities = list(length_distribution.values())
lengths = choice(
  candidate_lengths,
  size=dataset_size,
  replace=True,
  p=length_probabilities,
).tolist()

# Each datapoint is (token tensor with values in [0, 8), random binary label).
dataset = []
for length in lengths:
  tokens = torch.randint(0, 8, (length,))
  label = choice([0, 1])
  dataset.append((tokens, label))

# Preview: print datapoints from the start and stop right after the second
# datapoint whose input has at least two tokens has been shown.
print(f'{dataset_size} datapoints')
num_long_seen = 0
for index, datapoint in enumerate(dataset):
  print(f'{index}-th datapoint:')
  print(f'\t{datapoint}')
  if len(datapoint[0]) < 2:
    continue
  num_long_seen += 1
  if num_long_seen == 2:
    break
print('.\n.\n.')

1024 datapoints
0-th datapoint:
	(tensor([3]), 0)
1-th datapoint:
	(tensor([3]), 0)
2-th datapoint:
	(tensor([2]), 0)
3-th datapoint:
	(tensor([0, 5, 4, 0]), 1)
4-th datapoint:
	(tensor([2]), 0)
5-th datapoint:
	(tensor([3, 5, 4, 3]), 0)
.
.
.


## Standard way of training

Our model consists of embedding -> linear map -> max pooling (ignoring padded positions) -> final linear map 

In [None]:
class AModule(nn.Module):
  """Toy sequence model: embedding -> linear -> masked max-pool -> linear.

  `forward` takes a dict with two same-shaped 2-D tensors:
    'x':    token ids (indices into an 8-entry embedding table)
    'mask': boolean tensor, True at real positions and False at padding
  and returns one scalar per row of the batch.
  """

  def __init__(self, emb_dim=1024, hidden_dim=1024, device='cuda:0'):
    super().__init__()
    # 8 token types, each embedded into emb_dim dimensions.
    self.emb = nn.Embedding(8, emb_dim, device=device)
    self.linear = nn.Linear(emb_dim, hidden_dim, device=device)
    # Final projection down to a single output value.
    self.proj = nn.Linear(hidden_dim, 1, device=device)

  def forward(self, d):
    tokens, mask = d['x'], d['mask']
    assert tokens.shape == mask.shape
    assert tokens.dim() == 2

    hidden = self.linear(self.emb(tokens))  # (batch, length, hidden_dim)

    # Overwrite padded positions with -inf so that the max over the length
    # dimension below can never select them.
    keep = mask.unsqueeze(2).expand_as(hidden)
    minus_inf = torch.tensor(float('-inf'), device=hidden.device)
    hidden = torch.where(keep, hidden, minus_inf)

    pooled = hidden.max(dim=1).values  # (batch, hidden_dim)
    return self.proj(pooled)[:, 0]

We use the standard padding strategy, while remembering the padded positions.

In [None]:
def standard_collate_fn(data_list, target_device='cuda:0'):
  """Collate (sequence, label) pairs into one right-padded batch.

  Args:
    data_list: list of (1-D token tensor, label) pairs.
    target_device: device that the batch tensors are placed on.

  Returns:
    A pair (batch, y). `batch` is a dict with
      'x':    the sequences, zero-padded on the right to the longest length
      'mask': a boolean tensor of the same shape, True at real positions
    and `y` is a float32 tensor holding the labels.
  """
  sequences = []
  labels = []
  for sequence, label in data_list:
    sequences.append(sequence.to(device=target_device))
    labels.append(label)

  longest = max(map(len, sequences))

  padded_tokens = []
  validity_masks = []
  for sequence in sequences:
    right_padding = (0, longest - len(sequence))
    padded_tokens.append(pad(sequence, right_padding))
    # Padding an all-True mask with zeros marks the padded tail as False.
    all_true = torch.ones_like(sequence, dtype=torch.bool)
    validity_masks.append(pad(all_true, right_padding))

  batch = {
      'x': torch.stack(padded_tokens, dim=0),
      'mask': torch.stack(validity_masks, dim=0),
  }
  y = torch.tensor(labels, dtype=torch.float32, device=target_device)
  return batch, y

Usual training procedure:

In [None]:
# Baseline setup: padded batches of 256 datapoints, a wide model and plain
# SGD, with everything placed on the first GPU.
loader = DataLoader(
  dataset,
  batch_size=256,
  collate_fn=lambda samples: standard_collate_fn(samples, target_device='cuda:0'),
)
model = AModule(emb_dim=8192, hidden_dim=8192, device='cuda:0')
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)
def train_standard():
  """Run one pass over `loader`, taking one SGD step on `model` per batch.

  Uses the module-level `loader`, `model` and `optimizer`. The loss is the
  mean squared error between the model output and the labels.
  """
  optimizer.zero_grad()
  for batch, y in tqdm(loader):
    prediction = model(batch)
    squared_error = (y - prediction) ** 2
    loss = squared_error.mean()
    loss.backward()
    optimizer.step()
    # Clear the gradients so that they do not accumulate across batches.
    optimizer.zero_grad()
%time train_standard()

100%|██████████| 4/4 [00:03<00:00,  1.22it/s]

CPU times: user 3.28 s, sys: 7.48 ms, total: 3.29 s
Wall time: 3.29 s





In [None]:
!nvidia-smi | grep MiB

| N/A   57C    P0    47W / 250W |   4353MiB / 16280MiB |    100%      Default |


Check 
1. The time spent in training
2. The maximum GPU usage so far. 

## Training with Batch Grouping

Length thresholds are computed here from dataset-level length statistics via a heuristic algorithm. The `cost_fn` is linear because the memory usage of the preceding `AModule` grows linearly with the length of the batch (i.e. the maximum sequence length among the datapoints that make up the batch).

In [None]:
# Token length of every datapoint in the dataset.
lengths = [len(tokens) for tokens, _ in dataset]
# Search for length thresholds under a linear cost (cost of a length is the
# length itself). NOTE(review): `k` and `num_trials` are passed through to
# asym's `find_thresholds`; confirm their exact meaning in the asym docs.
length_thresholds = find_thresholds(
  lengths,
  cost_fn=lambda length: length,
  k=3,
  num_trials=100,
)
print('thresholds:', length_thresholds)

thresholds: [1, 4, 16]


To use the model with asym, we have to slightly modify the `AModule` class definition:
1. Inherit from `AnnotatedModule` instead of `nn.Module`.
2. Specify the input and output shapes ('b' stands for the batch dimension, and 'l' stands for a length dimension).
3. Specify the form of the mask tensor if you want it to be passed as a `forward()` argument. Here, `['seq']` means you want the mask tensor to have the form `'(b, l_seq)'`.

In [None]:
class AnnotatedAModule(AnnotatedModule):
  """asym-annotated variant of the toy model.

  Same computation as the plain model (embedding -> linear -> masked
  max-pool over the length dimension -> linear), but `forward` receives the
  token tensor and the mask as separate arguments, and the class declares
  its input/output shapes and mask form through the `get_*` methods.
  """

  def __init__(self, emb_dim=1024, hidden_dim=1024, device='cuda:0'):
    super().__init__()
    # 8 token types, each embedded into emb_dim dimensions.
    self.emb = nn.Embedding(8, emb_dim, device=device)
    self.linear = nn.Linear(emb_dim, hidden_dim, device=device)
    # Final projection down to a single output value.
    self.proj = nn.Linear(hidden_dim, 1, device=device)

  def forward(self, x, mask):
    assert x.shape == mask.shape
    assert x.dim() == 2

    hidden = self.linear(self.emb(x))  # (batch, length, hidden_dim)

    # Overwrite padded positions with -inf so that the max over the length
    # dimension below can never select them.
    keep = mask.unsqueeze(2).expand_as(hidden)
    minus_inf = torch.tensor(float('-inf'), device=hidden.device)
    hidden = torch.where(keep, hidden, minus_inf)

    pooled = hidden.max(dim=1).values  # (batch, hidden_dim)
    return self.proj(pooled)[:, 0]

  def get_mask_hint(self):
    # Ask for a mask of the form '(b, l_seq)' as the `mask` argument.
    return ['seq']

  def get_input_annot(self):
    # Input: batch dimension 'b' and length dimension 'l_seq'.
    return '(b, l_seq)'

  def get_output_annot(self):
    # Output: one value per batch element.
    return '(b)'


Here's our new collate_fn. `(B, L_sequence)` specifies the shape of the data. The label `'sequence'` will be recognized by `LengthThresholdGrouper` later on.

In [None]:
def new_collate_fn(data_list, target_device='cuda:0'):
  """Collate (sequence, label) pairs into an unpadded asym DataCollection.

  Args:
    data_list: list of (1-D token tensor, label) pairs.
    target_device: device that the tensors are placed on.

  Returns:
    A pair (dc, y): `dc` is a `DataCollection` annotated '(B, L_sequence)'
    that holds the individual sequences, and `y` is a float32 tensor
    holding the labels.
  """
  sequences, labels = [], []
  for sequence, label in data_list:
    sequences.append(sequence.to(device=target_device))
    labels.append(label)
  # No padding here: the sequences are handed over as a plain list.
  dc = DataCollection('(B, L_sequence)', sequences)
  y = torch.tensor(labels, dtype=torch.float32, device=target_device)
  return dc, y

Modified training procedure:

In [None]:
# Batch-grouping setup: same batch size, model width and optimizer settings
# as the standard run, but built around the asym collate function and the
# annotated model.
new_loader = DataLoader(dataset, batch_size=256, collate_fn=lambda batch: new_collate_fn(batch, target_device='cuda:0'))
new_model = AnnotatedAModule(emb_dim=8192, hidden_dim=8192, device='cuda:0')
# The optimizer must be built from new_model's parameters. Building it from
# the standard `model` would leave new_model untrained, and the timed
# optimizer step would not be operating on the model being benchmarked.
optimizer = torch.optim.SGD(new_model.parameters(), lr=0.001)
def train_new():
  """Run one pass over `new_loader` using asym batch grouping.

  Uses the module-level `new_loader`, `new_model`, `length_thresholds` and
  `optimizer`. The loss is the mean squared error between the model output
  and the labels.
  NOTE(review): the step is applied through the module-level `optimizer`;
  make sure it was built from `new_model.parameters()`.
  """
  optimizer.zero_grad()
  for batch, y in tqdm(new_loader):
    # Batch grouping: split the collection into groups by sequence length.
    batch.group(grouper=LengthThresholdGrouper('sequence', length_thresholds))
    grouped_out = batch.apply(new_model)
    # Merge the outputs back into the usual single minibatch.
    grouped_out.regroup(grouper=UniGrouper())
    prediction = grouped_out.data_groups[0].value
    squared_error = (y - prediction) ** 2
    loss = squared_error.mean()
    loss.backward()
    optimizer.step()
    # Clear the gradients so that they do not accumulate across batches.
    optimizer.zero_grad()
%time train_new()

100%|██████████| 4/4 [00:00<00:00,  7.16it/s]

CPU times: user 560 ms, sys: 2.42 ms, total: 562 ms
Wall time: 567 ms





In [None]:
!nvidia-smi | grep MiB

| N/A   56C    P0    46W / 250W |   4353MiB / 16280MiB |     46%      Default |


Observe the improved efficiency. 