In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import time
from opt_einsum import contract
import opt_einsum

In [3]:
t = torch.tensor([[[1, 2], [2, 3]], [[4, 5], [6, 7]]])
index = torch.tensor([0, 1])
print(t[[0, 1], index, :])
# torch.gather(t, 1, index)
# torch.scatter(t, 1, index, t)

tensor([[1, 2],
        [6, 7]])


In [4]:
%load_ext line_profiler

In [5]:
BATCH_SIZE = 5

def assert_shape(x, shape):
  if x.shape != shape:
    raise AssertionError('invalid shape; got {} but expected {}.'.format(x.shape, shape))

class DenseFF(nn.Module):
  def __init__(self, d_model, d_ff):
    super(DenseFF, self).__init__()
    self.f1 = nn.Linear(d_model, d_ff)
    self.f2 = nn.Linear(d_ff, d_model)

  def forward(self, x):
    ff = self.f1(x)
    ff = F.relu(ff)
    out = self.f2(ff)
    return out

class DenseFFEinsum(nn.Module):
  def __init__(self, d_model, d_ff):
    super(DenseFFEinsum, self).__init__()

    self.d_model = d_model
    self.d_ff = d_ff

    self.f1 = nn.Parameter(torch.Tensor(d_ff, d_model))
    self.f2 = nn.Parameter(torch.Tensor(d_ff, d_model))

  def forward(self, x):
    inner = torch.einsum('bm,nm->bn', x, self.f1)
    inner = F.relu(inner)
    output = torch.einsum('bn,nm->bm', inner, self.f2)
    assert_shape(output, (BATCH_SIZE, self.d_model))
    return output

CONTRACTS = dict()

class LowRank(nn.Module):
  def __init__(self, d_model, d_lowrank, d_output=None):
    super(LowRank, self).__init__()
    if d_output is None:
      d_output = d_model
    
    self.f1 = nn.Parameter(torch.Tensor(d_model, d_lowrank))
    self.f2 = nn.Parameter(torch.Tensor(d_lowrank, d_output))
    # self.contract =

  def forward(self, x):
    out = torch.einsum('bm,ml,lo->bo', x, self.f1, self.f2)
    # out = contract('bm,ml,lo->bo', x, self.f1, self.f2)
    # lowrank = torch.einsum('bm,ml->bl', x, self.f1)
    # out = torch.einsum('bl,lo->bo', lowrank, self.f2)
    return out


def stop_gradient(x):
  return x.detach()


class GradientsLike(nn.Module):
  def __init__(self):
    super(GradientsLike, self).__init__()

  def forward(self, x):
    return x - stop_gradient(x)


class SparseController(nn.Module):
  def __init__(self, d_model, d_lowrank, d_ff, N):
    super(SparseController, self).__init__()
    assert d_ff % N == 0
    self.lowrank = LowRank(d_model, d_lowrank, d_ff)
    self.N = N
    self.d_model = d_model
    self.d_ff = d_ff
    self.d_lowrank = d_lowrank

  def forward(self, x):
    N = self.N
    assert_shape(x, (BATCH_SIZE, self.d_model))
    out = self.lowrank(x)

    out = out.view(BATCH_SIZE, -1, N)
    assert out.shape == (BATCH_SIZE, self.d_ff//N, N)

    # probs = F.softmax(out, dim=-1)
    # TODO(jaszczur): change to discrete
    # result = probs

    result = out
    assert result.shape == (BATCH_SIZE, self.d_ff//N, N)
    return result

class SparseFF(nn.Module):
  def __init__(self, d_model, d_ff, d_lowrank, N):
    super(SparseFF, self).__init__()
    assert d_ff % N == 0

    n_expertsets = d_ff // N

    self.d_model = d_model
    self.d_ff = d_ff
    self.d_lowrank = d_lowrank
    self.N = N
    self.controller = SparseController(d_model, d_lowrank, d_ff, N)

    self.f1 = nn.Parameter(torch.Tensor(n_expertsets, N, d_model))
    # TODO(jaszczur): add biases
    # self.f1 = nn.Linear(d_model, d_ff)
    # self.f2 = nn.Linear(d_ff, d_model)
    self.f2 = nn.Parameter(torch.Tensor(n_expertsets, N, d_model))

  def forward(self, x):
    N = self.N
    assert x.shape == (BATCH_SIZE, self.d_model)
    controller_output = self.controller(x)
    if self.training:
      inner = torch.einsum('bm,enm->ben', x, self.f1)
      # inner = self.f1(x)
      # inner = inner.view(BATCH_SIZE, self.d_ff//N, N)

      assert_shape(inner, (BATCH_SIZE, self.d_ff//N, N))
      assert_shape(controller_output, (BATCH_SIZE, self.d_ff//N, N))
      inner = F.relu(inner) * controller_output

      output = torch.einsum('ben,enm->bm', inner, self.f2)
      # inner = inner.view(BATCH_SIZE, self.d_ff)
      # output = self.f2(inner)
      assert_shape(output, (BATCH_SIZE, self.d_model))
      return output
    else:
      controller_indexes = torch.argmax(controller_output, dim=-1, keepdim=True)
      
      assert BATCH_SIZE == 1
      assert_shape(controller_indexes, (BATCH_SIZE, self.d_ff//N, 1))
      controller_indexes = controller_indexes.view(self.d_ff//N)
      assert_shape(self.f1, (self.d_ff//N, N, self.d_model))

      rangeE = torch.arange(self.d_ff//N)

      f1p = self.f1[rangeE, controller_indexes]
      f2p = self.f2[rangeE, controller_indexes]
      # f1p = torch.index_select(self.f1, -1, controller_indexes)
      # f2p = torch.index_select(self.f2, -1, controller_indexes)

      assert_shape(f1p, (self.d_ff//N, self.d_model))
      assert_shape(f2p, (self.d_ff//N, self.d_model))

      inner = torch.einsum('bm,em->be', x, f1p)

      assert_shape(inner, (BATCH_SIZE, self.d_ff//N))

      inner = F.relu(inner)
      output = torch.einsum('be,em->bm', inner, f2p)
      assert_shape(output, (BATCH_SIZE, self.d_model))
      return output


class Residual(nn.Module):
  def __init__(self, layer):
    super(Residual, self).__init__()
    self.fflayer = layer

  def forward(self, x):
    return self.fflayer(x) + x


class Model(nn.Module):
  def __init__(self, layers, d_model, d_ff, d_lowrank, sparsity, version):
    super(Model, self).__init__()
    if 'sparse' in version:
      layer_fun = lambda: SparseFF(d_model, d_ff, d_lowrank, sparsity)
    elif 'einsum' in version:
      layer_fun = lambda: DenseFFEinsum(d_model, d_ff)
    else:
      layer_fun = lambda: DenseFF(d_model, d_ff)
    self.layers = nn.ModuleList(
        [Residual(layer_fun())
         for i in range(layers)])
    
  def forward(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

In [16]:
# model = Model(3, 128, 4*128, 32, 16, 'sparse')

CUDA = torch.device("cuda")

def timemodel(batch, sample, layers, d_model, d_ff, d_lowrank, sparsity, version):
  model = Model(layers, d_model, d_ff, d_lowrank, sparsity, version)
  # model.to(CUDA)
  sample = [torch.Tensor(np.random.random((batch, d_model)))]
  if 'train' in version:
    model.train()
  else:
    model.eval()
  start = time.time()
  with torch.no_grad():
    for i in range(REPEAT):
      for s in sample:
        r = model(s)
  return time.time() - start

In [7]:
BATCH_SIZE = 1
SAMPLE = 100
REPEAT = 100
LAYERS = 20
DMODEL = 1024
DFF = 4 * 1024
DLOWRANK = 32
SPARSITY = 1024

In [17]:
print(torch.cuda.mem_get_info())
torch.cuda.empty_cache()
print(torch.cuda.mem_get_info())
import gc
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.mem_get_info())

(2087911424, 4294705152)
(2773680128, 4294705152)
(2773680128, 4294705152)


In [19]:
%lprun -f SparseFF.forward print("sparse-eval", timemodel(BATCH_SIZE, SAMPLE, LAYERS, DMODEL, DFF, DLOWRANK, SPARSITY, "sparse-eval"))

sparse-eval 0.6223287582397461


In [10]:
%lprun -f LowRank.forward print("sparse-eval", timemodel(BATCH_SIZE, SAMPLE, LAYERS, DMODEL, DFF, DLOWRANK, SPARSITY, "sparse-eval"))

sparse-eval 0.7631902694702148


In [11]:
%lprun -f SparseFF.forward print("sparse-train", timemodel(BATCH_SIZE, SAMPLE, LAYERS, DMODEL, DFF, DLOWRANK, SPARSITY, "sparse-train"))

sparse-train 0.6327347755432129


In [20]:
%lprun -f DenseFFEinsum.forward  print("dense-einsum", timemodel(BATCH_SIZE, SAMPLE, LAYERS, DMODEL, DFF, DLOWRANK, SPARSITY, "dense-einsum"))

dense-einsum 3.8361973762512207


In [13]:
BATCH_SIZE = 1
timemodel(1, 1000, 16, 1024, 4*1024, 32, 64, True, eval=False)

TypeError: timemodel() got an unexpected keyword argument 'eval'

In [None]:
timemodel(1, 1000, 16, 1024, 4*1024, 32, 64, False)

In [None]:
# class ResNet(nn.Module):
#   def __init__(self):
#     super(ResNet, self).__init__()
#     # After flattening an image of size 28x28 we have 784 inputs

#     d_model = 128
#     d_ff = 256
#     num_layers = 3

#     d_lowrank = 16
#     N = 4

#     if USE_SPARSE:
#       fff = lambda: SparseFF(d_model, d_ff, d_lowrank, N)
#     else:
#       fff = lambda: DenseFF(d_model, d_ff)
    

#     self.fc1 = nn.Linear(784, d_model)
#     self.layers = nn.ModuleList(
#         [Residual(fff()) for i in range(num_layers)])
#     self.output = nn.Linear(d_model, 10)

#   def forward(self, x):
#     assert x.shape == (BATCH_SIZE, 1, 28, 28)
#     x = torch.flatten(x, 1)
#     assert x.shape == (BATCH_SIZE, 28*28)
#     x = self.fc1(x)
#     for layer in self.layers:
#       x = layer(x)
#     x = self.output(x)
#     output = F.log_softmax(x, dim=1)
#     return output

In [None]:
# class Net(nn.Module):
#     def __init__(self):
#         super(Net, self).__init__()
#         # After flattening an image of size 28x28 we have 784 inputs
#         self.fc1 = nn.Linear(784, 128)
#         self.fc2 = nn.Linear(128, 128)
#         self.fc3 = nn.Linear(128, 10)

#     def forward(self, x):
#         x = torch.flatten(x, 1)
#         x = self.fc1(x)
#         x = F.relu(x)
#         x = self.fc2(x)
#         x = F.relu(x)
#         x = self.fc3(x)
#         output = F.log_softmax(x, dim=1)
#         return output


# def train(model, device, train_loader, optimizer, epoch, log_interval):
#     model.train()
#     for batch_idx, (data, target) in enumerate(train_loader):
#       # print(data.shape)
#       assert data.shape == (BATCH_SIZE, 1, 28, 28)  # B, 1(C/F), H, W
#       data, target = data.to(device), target.to(device)
#       optimizer.zero_grad()
#       output = model(data)
#       loss = F.nll_loss(output, target)
#       loss.backward()
#       optimizer.step()
#       if batch_idx % log_interval == 0:
#           print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
#               epoch, batch_idx * len(data), len(train_loader.dataset),
#               100. * batch_idx / len(train_loader), loss.item()))


# def test(model, device, test_loader):
#     model.eval()
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             output = model(data)
#             test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
#             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
#             correct += pred.eq(target.view_as(pred)).sum().item()

#     test_loss /= len(test_loader.dataset)

#     print("\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))

In [None]:
# BATCH_SIZE = 128
# test_BATCH_SIZE = 1000
# epochs = 2
# lr = 1e-3
# # use_cuda = False
# seed = 1
# log_interval = 10000

# use_cuda = torch.cuda.is_available()

# torch.manual_seed(seed)
# device = torch.device("cuda" if use_cuda else "cpu")

# train_kwargs = {"BATCH_SIZE": BATCH_SIZE}
# test_kwargs = {"BATCH_SIZE": test_BATCH_SIZE}
# if use_cuda:
#     cuda_kwargs = {"num_workers": 1,
#                     "pin_memory": True,
#                     "shuffle": True}
#     train_kwargs.update(cuda_kwargs)
#     test_kwargs.update(cuda_kwargs)

In [None]:
# transform=transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.1307,), (0.3081,))
#     ])
# dataset1 = datasets.MNIST("../data", train=True, download=True,
#                     transform=transform)
# dataset2 = datasets.MNIST("../data", train=False,
#                     transform=transform)
# train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
# test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [None]:
# USE_SPARSE = False

# model = ResNet().to(device)
# optimizer = optim.Adam(model.parameters(), lr=lr)

# for epoch in range(1, epochs + 1):
#     train(model, device, train_loader, optimizer, epoch, log_interval)
#     test(model, device, test_loader)