#Atención desnuda.

El mecanismo de atención paso a paso

Notebook original: Jared Ostmeyer

In [1]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m30.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.2


In [2]:
import torchvision
import torch
import torchmetrics

Data Pipeline:

In [3]:

##########################################################################################
# Carga datos del MNIST
##########################################################################################

# Cargar datos de entrenamiento, validación y test del... MNIST!
#
def load_mnist(seed=None, device=torch.device('cpu')):

  # Usamos random? Cuando llamamos decimos que sí.
  #
  generator = torch.Generator(device=device)
  if seed is not None:
    generator.manual_seed(seed)

  # Cargamos el dataset MNIST
  #
  samples_train = torchvision.datasets.MNIST('./', train=True, download=True)
  samples_test = torchvision.datasets.MNIST('./', train=False, download=True)

  # Acomodar atributos y categorías
  #
  xs = samples_train.data.to(device)
  num = xs.shape[0]
  xs = xs.reshape([ num, 28**2, 1 ])
  xs = xs.type(torch.float32)
  ys = samples_train.train_labels.to(device)

  xs_test = samples_test.data.to(device)
  num_test = xs_test.shape[0]
  xs_test = xs_test.reshape([ num_test, 28**2, 1 ])
  xs_test = xs_test.type(torch.float32)
  ys_test = samples_test.test_labels.to(device)

  # train/valid split
  #
  num_train = int(num*5/6)
  num_val = num-num_train

  js = torch.randperm(num, generator=generator)
  js_train = js[:num_train]
  js_val = js[num_train:]

  xs_train = xs[js_train]
  ys_train = ys[js_train]

  xs_val = xs[js_val]
  ys_val = ys[js_val]

  # Normalizamos los atributos
  # Atención: Vamos a usar media y desvío -puntaje Z-
  # No vamos a reescalar [0-1]

  mean = torch.mean(xs_train, axis=0, keepdim=True)
  variance = torch.var(xs_train, axis=0, keepdim=True)

  xs_train = (xs_train-mean)/torch.std(variance+1.0E-8)
  xs_val = (xs_val-mean)/torch.std(variance+1.0E-8)
  xs_test = (xs_test-mean)/torch.std(variance+1.0E-8)

  return xs_train, ys_train, xs_val, ys_val, xs_test, ys_test

In [None]:
##########################################################################################
# Modelo
##########################################################################################

class SelfAttentionModel(torch.nn.Module):
  def __init__(self, num_steps, num_channels, num_outputs, **kwargs):
    super().__init__(**kwargs)

    # Inicializamos las componentes de self-attention
    # Cada peso valdrá entre -1/num_channels**0.5, 1/num_channels**0.5
    self.K = torch.nn.Parameter((2.0*torch.rand(num_channels, num_channels)-1.0)/num_channels**0.5)
    self.Q = torch.nn.Parameter((2.0*torch.rand(num_channels, num_channels)-1.0)/num_channels**0.5)
    self.V = torch.nn.Parameter((2.0*torch.rand(num_channels, num_channels)-1.0)/num_channels**0.5)

    self.softmax = torch.nn.Softmax(dim=1)

    # Inicializamos capa de salida
    #
    self.out = torch.nn.Linear(num_steps*num_channels, num_outputs)

  def forward(self, x):

    batch_size, num_steps, num_channels = x.shape

    # Definimos self-attention
    #
    y = []
    for i in range(batch_size): # Dentro del batch, vamos de a 1. Nada de paralelizar.

      x_i = x[i,:,:] # x_i shape = [ num_steps, num_channels ]

      x_k_i = torch.matmul(x_i, self.K) # x_k_i shape = [ num_steps, num_channels ]
      x_q_i = torch.matmul(x_i, self.Q) # x_q_i shape = [ num_steps, num_channels ]
      x_v_i = torch.matmul(x_i, self.V) # x_v_i shape = [ num_steps, num_channels ]

      w_i = self.softmax(torch.matmul(x_q_i, x_k_i.T)/num_channels**0.5) # w_i shape = [ num_steps, num_steps ]
      y_i = torch.matmul(w_i, x_v_i) # y_i shape = [ num_steps, num_channels ]

      y.append(y_i)
    y = torch.stack(y, axis=0) # y shape = [ batch_size, num_steps, num_channels ]

    # Flatten output
    #
    y_flat = y.reshape([ batch_size, num_steps*num_channels ]) # y_flat shape = [ batch_size, num_steps*num_channels ]

    # Capa de salida
    #
    l = self.out(y_flat) # l shape = [ batch_size, num_outputs ]

    return l

##########################################################################################
# Crear instancia de modelo, métricas y optimizador
##########################################################################################

model = SelfAttentionModel(num_steps=28**2, num_channels=1, num_outputs=10)
probability = torch.nn.Softmax(dim=1)

loss = torch.nn.CrossEntropyLoss()
accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=10)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

##########################################################################################
# Sampleamos dataset a tensores
##########################################################################################

xs_train, ys_train, xs_val, ys_val, xs_test, ys_test = load_mnist(seed=42)

dataset_train = torch.utils.data.TensorDataset(xs_train, ys_train)
sampler_train = torch.utils.data.RandomSampler(dataset_train, replacement=True)
loader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=16, sampler=sampler_train, drop_last=True)

##########################################################################################
# Modelo
##########################################################################################

i_better = -1
e_better = 1.0e8
a_better = 0.0
state_better = {}

# Loopeamos el dataset por cuántas épocas?
#
for i in range(128):

  # Entrenamos
  #
  model.train()
  e_train = 0.0
  a_train = 0.0
  for xs_batch, ys_batch in iter(loader_train): # Must use `iter` or `enumerate` for efficiency
    ls_batch = model(xs_batch)
    ps_batch = probability(ls_batch) # De logit (negativo de la derivada de la función de entropía) a probabilidades
    e_batch = loss(ls_batch, ys_batch) # CrossEntropyLoss quiere logits
    a_batch = accuracy(ps_batch, ys_batch)
    optimizer.zero_grad()
    e_batch.backward()
    optimizer.step()
    e_train += e_batch.detach()/len(loader_train) # error promedio por época
    a_train += a_batch.detach()/len(loader_train) # accuracy promedio por época

  # Vemos si le fue bien en validación
  #
  model.eval()
  with torch.no_grad():
    ls_val = model(xs_val)
    ps_val = probability(ls_val) # De logit a probabilidades
    e_val = loss(ls_val, ys_val) # CrossEntropyLoss quiere logits
    a_val = accuracy(ps_val, ys_val)
    if e_val < e_better: # Graba la mejor epoch
      i_better = i
      e_better = e_val
      a_better = a_val
      state_better = model.state_dict()

  # Reporte
  #
  print(
    'i: '+str(i),
    'e_train: {:.5f}'.format(float(e_train)/0.693)+' bits',
    'a_train: {:.1f}'.format(100.0*float(a_train))+' %',
    'e_val: {:.5f}'.format(float(e_val)/0.693)+' bits',
    'a_val: {:.1f}'.format(100.0*float(a_val))+' %',
    sep='\t', flush=True
  )

model.eval()
model.load_state_dict(state_better)
with torch.no_grad():
  ls_test = model(xs_test)
  ps_test = probability(ls_test) # De logit a probabilidades
  e_test = loss(ls_test, ys_test) # CrossEntropyLoss quiere logits
  a_test = accuracy(ps_test, ys_test)

print(
  'e_test: {:.5f}'.format(float(e_test)/0.693)+' bits',
  'a_test: {:.1f}'.format(100.0*float(a_test))+' %',
  sep='\t', flush=True
)



100%|██████████| 9.91M/9.91M [00:00<00:00, 12.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 341kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.34MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.62MB/s]


i: 0	e_train: 2.99019 bits	a_train: 20.4 %	e_val: 2.98537 bits	a_val: 21.7 %
i: 1	e_train: 2.23472 bits	a_train: 42.0 %	e_val: 0.59076 bits	a_val: 88.6 %
i: 2	e_train: 0.52510 bits	a_train: 87.5 %	e_val: 0.50759 bits	a_val: 90.3 %
i: 3	e_train: 0.46751 bits	a_train: 88.8 %	e_val: 0.47586 bits	a_val: 91.2 %
i: 4	e_train: 0.44340 bits	a_train: 89.3 %	e_val: 0.46530 bits	a_val: 91.5 %
i: 5	e_train: 0.43390 bits	a_train: 89.7 %	e_val: 0.45716 bits	a_val: 91.3 %
i: 6	e_train: 0.43049 bits	a_train: 89.9 %	e_val: 0.46653 bits	a_val: 91.1 %
i: 7	e_train: 0.40598 bits	a_train: 90.3 %	e_val: 0.45990 bits	a_val: 91.3 %
i: 8	e_train: 0.40517 bits	a_train: 90.4 %	e_val: 0.47208 bits	a_val: 91.2 %
i: 9	e_train: 0.39831 bits	a_train: 90.5 %	e_val: 0.44770 bits	a_val: 91.7 %
i: 10	e_train: 0.38581 bits	a_train: 90.9 %	e_val: 0.46154 bits	a_val: 91.5 %
i: 11	e_train: 0.38053 bits	a_train: 90.9 %	e_val: 0.46032 bits	a_val: 91.5 %
i: 12	e_train: 0.39180 bits	a_train: 90.8 %	e_val: 0.45194 bits	a_val: 91.