In [46]:
"""PyTorch OpenAI GPT-2 model."""


import logging
import math
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.functional import gelu

from transformers.configuration_gpt2 import GPT2Config
from transformers.file_utils import add_start_docstrings
from transformers.modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer


logger = logging.getLogger(__name__)

GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",
}


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
  """ Load tf checkpoints in a pytorch model
  """
  try:
    import re
    import tensorflow as tf
  except ImportError:
    logger.error(
        "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
        "https://www.tensorflow.org/install/ for installation instructions."
    )
    raise
  tf_path = os.path.abspath(gpt2_checkpoint_path)
  logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
  # Load weights from TF model
  init_vars = tf.train.list_variables(tf_path)
  names = []
  arrays = []
  for name, shape in init_vars:
    logger.info("Loading TF weight {} with shape {}".format(name, shape))
    array = tf.train.load_variable(tf_path, name)
    names.append(name)
    arrays.append(array.squeeze())

  for name, array in zip(names, arrays):
    name = name[6:]  # skip "model/"
    name = name.split("/")
    pointer = model
    for m_name in name:
      if re.fullmatch(r"[A-Za-z]+\d+", m_name):
        scope_names = re.split(r"(\d+)", m_name)
      else:
        scope_names = [m_name]
      if scope_names[0] == "w" or scope_names[0] == "g":
        pointer = getattr(pointer, "weight")
      elif scope_names[0] == "b":
        pointer = getattr(pointer, "bias")
      elif scope_names[0] == "wpe" or scope_names[0] == "wte":
        pointer = getattr(pointer, scope_names[0])
        pointer = getattr(pointer, "weight")
      else:
        pointer = getattr(pointer, scope_names[0])
      if len(scope_names) >= 2:
        num = int(scope_names[1])
        pointer = pointer[num]
    try:
      assert pointer.shape == array.shape
    except AssertionError as e:
      e.args += (pointer.shape, array.shape)
      raise
    logger.info("Initialize PyTorch weight {}".format(name))
    pointer.data = torch.from_numpy(array)
  return model


class Attention(nn.Module):
  # 속성을 정의함, 입력으로 nx, n_ctx, config, scale 적용 여부를 받음
  def __init__(self, nx, n_ctx, config, scale=False):
    # 파이토치 클래스를 사용하기 위한 일반적인 문법
    super().__init__()
    # 첫번째 속성 : 객체는 output_attentions 속성에 config의 output_attentions을 담는다.
    self.output_attentions = config.output_attentions
    
    # 속성 n_state에 nx를 담는다. 여기서 n_state는 768차원이다.
    n_state = nx  # in Attention: n_state=768 (nx=n_embd)
    # [switch nx => n_state from Block to Attention to keep identical to TF implem]
    # n_state를 config.n_head로 나누었을 때 나머지가 0이 되는지 확인한다.
    assert n_state % config.n_head == 0
    # 속성 register_buffer("bias", torch.tril(768, 768), view(1,1,768,768))
    # 레지스터버퍼 함수는 name과 tensor를 입력으로 받는다. 
    self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
    # n_head에 config.n_head를 담는다.
    self.n_head = config.n_head
    # 스플릿사이즈에 n_state 값을 담는다.

    self.split_size = n_state
    self.scale = scale
    # c_attn의 w매트릭스는 768 * 3 , 768
    self.c_attn = Conv1D(n_state * 3, nx)
    # c_proj의 w매트릭스는 768, 768
    self.c_proj = Conv1D(n_state, nx)
    # 드랍아웃
    self.attn_dropout = nn.Dropout(config.attn_pdrop)
    self.resid_dropout = nn.Dropout(config.resid_pdrop)
    # pruned_heads는 집합 형식
    self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        # Convert to set and emove already pruned heads
        heads = set(heads) - self.pruned_heads
        for head in heads:
      # Compute how many pruned heads are before the head and move the index accordingly
      head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
      mask[head] = 0
    mask = mask.view(-1).contiguous().eq(1)
    index = torch.arange(len(mask))[mask].long()
    index_attn = torch.cat(
        [index, index + self.split_size, index + (2 * self.split_size)])

    # Prune conv1d layers
    self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
    self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

    # Update hyper params
    self.split_size = (self.split_size // self.n_head) * \
        (self.n_head - len(heads))
    self.n_head = self.n_head - len(heads)
    self.pruned_heads = self.pruned_heads.union(heads)

  # self-attenrion 함수
  def _attn(self, q, k, v, attention_mask=None, head_mask=None):
    w = q * k
    # b에 bias를 담는데 적당하게 담도록 계산하는 부분 
    b = 0.1
    # w에 담는다. w * b - 10000 * (1 - b)
    #w = w * b - 1e4 * (1 - b)

    if attention_mask is not None:
      # Apply the attention mask
      w = w + attention_mask

    #w = nn.Softmax(w)

    # Mask heads if we want to
    if head_mask is not None:
      w = w * head_mask

    outputs = w * v
    return outputs

In [47]:
x = Attention

In [48]:
print(x)

<class '__main__.Attention'>


In [49]:
x = x._attn(1, 2, 3, 4)

In [50]:
print(x)

24


In [55]:
q = 3
k = 4
v = 5
b= 0.5
w = q * k
w = w * b
print(w)
p = 1e3 * (1 - b)
print(b)
print(p)
output = w * b - p
print(output)

6.0
0.5
500.0
-497.0


In [95]:
x = torch.randn([1,1,768,768])
print(x)
x = x.permute(0, 2, 1, 3)

tensor([[[[ 1.3781,  1.1454,  0.6767,  ...,  0.3131,  0.2286, -1.3658],
          [-0.6200, -0.2225,  0.0104,  ..., -1.4961, -0.3580, -1.0580],
          [ 1.1362,  0.6315,  0.2823,  ...,  0.4099, -0.7750, -0.4081],
          ...,
          [-0.3098,  0.0504,  0.2575,  ...,  0.8729, -0.4162, -0.7146],
          [-1.1531, -0.6457, -0.7778,  ...,  1.3726, -0.6686,  0.2975],
          [-0.2236, -1.4529, -0.0756,  ..., -0.9635, -1.7315, -0.4358]]]])


In [96]:
print(x)

tensor([[[[ 1.3781,  1.1454,  0.6767,  ...,  0.3131,  0.2286, -1.3658]],

         [[-0.6200, -0.2225,  0.0104,  ..., -1.4961, -0.3580, -1.0580]],

         [[ 1.1362,  0.6315,  0.2823,  ...,  0.4099, -0.7750, -0.4081]],

         ...,

         [[-0.3098,  0.0504,  0.2575,  ...,  0.8729, -0.4162, -0.7146]],

         [[-1.1531, -0.6457, -0.7778,  ...,  1.3726, -0.6686,  0.2975]],

         [[-0.2236, -1.4529, -0.0756,  ..., -0.9635, -1.7315, -0.4358]]]])


In [97]:
x = x.contiguous()

In [107]:
print(x)

tensor([[[[ 1.3781,  1.1454,  0.6767,  ...,  0.3131,  0.2286, -1.3658]],

         [[-0.6200, -0.2225,  0.0104,  ..., -1.4961, -0.3580, -1.0580]],

         [[ 1.1362,  0.6315,  0.2823,  ...,  0.4099, -0.7750, -0.4081]],

         ...,

         [[-0.3098,  0.0504,  0.2575,  ...,  0.8729, -0.4162, -0.7146]],

         [[-1.1531, -0.6457, -0.7778,  ...,  1.3726, -0.6686,  0.2975]],

         [[-0.2236, -1.4529, -0.0756,  ..., -0.9635, -1.7315, -0.4358]]]])


In [108]:
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)

In [109]:
print(new_x_shape)

torch.Size([1, 768, 768])


In [110]:
x_size = x.size()[:2]

In [111]:
print(x_size)

torch.Size([1, 768])


In [112]:
x_size2 = x.size(-2)

In [113]:
print(x_size2)

1


In [114]:
x_size3 = x.size(-1)
print(x_size3)

768


In [115]:
r = x.view(*new_x_shape)
print(r)

tensor([[[ 1.3781,  1.1454,  0.6767,  ...,  0.3131,  0.2286, -1.3658],
         [-0.6200, -0.2225,  0.0104,  ..., -1.4961, -0.3580, -1.0580],
         [ 1.1362,  0.6315,  0.2823,  ...,  0.4099, -0.7750, -0.4081],
         ...,
         [-0.3098,  0.0504,  0.2575,  ...,  0.8729, -0.4162, -0.7146],
         [-1.1531, -0.6457, -0.7778,  ...,  1.3726, -0.6686,  0.2975],
         [-0.2236, -1.4529, -0.0756,  ..., -0.9635, -1.7315, -0.4358]]])


In [116]:
shape = x.size()[:-1]

In [117]:
print(shape)

torch.Size([1, 768, 1])


In [118]:
shape = shape + (12, x.size(-1) // 12)

In [119]:
print(shape)

torch.Size([1, 768, 1, 12, 64])


In [120]:
x = x.view(*new_x_shape)

In [125]:
print(x)
print(x.size())

tensor([[[ 1.3781,  1.1454,  0.6767,  ...,  0.3131,  0.2286, -1.3658],
         [-0.6200, -0.2225,  0.0104,  ..., -1.4961, -0.3580, -1.0580],
         [ 1.1362,  0.6315,  0.2823,  ...,  0.4099, -0.7750, -0.4081],
         ...,
         [-0.3098,  0.0504,  0.2575,  ...,  0.8729, -0.4162, -0.7146],
         [-1.1531, -0.6457, -0.7778,  ...,  1.3726, -0.6686,  0.2975],
         [-0.2236, -1.4529, -0.0756,  ..., -0.9635, -1.7315, -0.4358]]])
torch.Size([1, 768, 768])


In [127]:
x_t = x.permute(2, 0, 1)
x_y = x.permute(2, 1, 0)

In [128]:
print(x_t)

tensor([[[ 1.3781, -0.6200,  1.1362,  ..., -0.3098, -1.1531, -0.2236]],

        [[ 1.1454, -0.2225,  0.6315,  ...,  0.0504, -0.6457, -1.4529]],

        [[ 0.6767,  0.0104,  0.2823,  ...,  0.2575, -0.7778, -0.0756]],

        ...,

        [[ 0.3131, -1.4961,  0.4099,  ...,  0.8729,  1.3726, -0.9635]],

        [[ 0.2286, -0.3580, -0.7750,  ..., -0.4162, -0.6686, -1.7315]],

        [[-1.3658, -1.0580, -0.4081,  ..., -0.7146,  0.2975, -0.4358]]])


In [129]:
print(x_y)

tensor([[[ 1.3781],
         [-0.6200],
         [ 1.1362],
         ...,
         [-0.3098],
         [-1.1531],
         [-0.2236]],

        [[ 1.1454],
         [-0.2225],
         [ 0.6315],
         ...,
         [ 0.0504],
         [-0.6457],
         [-1.4529]],

        [[ 0.6767],
         [ 0.0104],
         [ 0.2823],
         ...,
         [ 0.2575],
         [-0.7778],
         [-0.0756]],

        ...,

        [[ 0.3131],
         [-1.4961],
         [ 0.4099],
         ...,
         [ 0.8729],
         [ 1.3726],
         [-0.9635]],

        [[ 0.2286],
         [-0.3580],
         [-0.7750],
         ...,
         [-0.4162],
         [-0.6686],
         [-1.7315]],

        [[-1.3658],
         [-1.0580],
         [-0.4081],
         ...,
         [-0.7146],
         [ 0.2975],
         [-0.4358]]])


In [130]:
x_s = split(x)

NameError: name 'split' is not defined

In [131]:
key = torch.randn([2,2,4,4])
value = torch.randn([2,2,4,4])
present = torch.stack((key.transpose(-2, -1), value))

In [132]:
print(key)

tensor([[[[ 0.0725, -0.1802, -0.4967, -0.2464],
          [-0.9068,  0.4818, -0.5466,  0.2060],
          [ 0.3689, -0.5151,  0.2059,  1.4127],
          [-0.6065, -1.5314, -0.7181, -0.2939]],

         [[-0.4473,  0.2260, -1.0767, -0.7783],
          [-0.1643, -1.2500, -0.5112,  0.8054],
          [-0.9644, -1.3825, -0.3951, -0.8994],
          [-0.2397,  0.3818, -0.7722,  0.3342]]],


        [[[ 2.1604, -0.1998,  0.9507,  0.4245],
          [ 0.3155, -0.8320,  0.4615,  0.9241],
          [-0.1537, -1.5395, -0.2687, -0.1006],
          [ 1.3486,  0.1883,  0.1001,  0.4842]],

         [[-0.9200,  0.1660, -0.5125, -0.9595],
          [-0.2118,  1.2924,  0.4168, -0.2691],
          [ 0.4307, -0.6666, -1.0174,  1.4210],
          [-1.9475, -0.9043,  1.0805, -1.9117]]]])


In [133]:
print(value)

tensor([[[[-0.0616, -2.3956, -0.3305,  1.0422],
          [ 1.2706,  0.0987, -0.0407, -1.1116],
          [-0.1327,  0.0740, -0.7086, -0.7971],
          [ 1.0783,  0.4304,  1.3829,  0.0207]],

         [[ 0.6801,  0.7327,  0.1529,  0.5038],
          [-1.0100,  0.3607, -0.5664, -1.7524],
          [ 0.1580,  0.0985, -2.1577,  0.4655],
          [ 0.9390, -0.8309,  0.1500,  0.8701]]],


        [[[-0.3441,  0.1697,  0.2136, -0.2131],
          [-0.3742,  0.4407,  0.1253, -0.5867],
          [-2.2655,  0.4139, -0.7282,  1.7367],
          [ 1.0252,  0.3223, -1.0377,  1.7965]],

         [[-0.0986,  1.4246, -0.6832, -0.4942],
          [-0.3007,  0.9262, -0.5225,  1.0475],
          [-0.5953,  0.0113, -1.0411, -0.7125],
          [-0.7568, -0.2003,  1.0781,  0.4612]]]])


In [134]:
print(present)

tensor([[[[[ 0.0725, -0.9068,  0.3689, -0.6065],
           [-0.1802,  0.4818, -0.5151, -1.5314],
           [-0.4967, -0.5466,  0.2059, -0.7181],
           [-0.2464,  0.2060,  1.4127, -0.2939]],

          [[-0.4473, -0.1643, -0.9644, -0.2397],
           [ 0.2260, -1.2500, -1.3825,  0.3818],
           [-1.0767, -0.5112, -0.3951, -0.7722],
           [-0.7783,  0.8054, -0.8994,  0.3342]]],


         [[[ 2.1604,  0.3155, -0.1537,  1.3486],
           [-0.1998, -0.8320, -1.5395,  0.1883],
           [ 0.9507,  0.4615, -0.2687,  0.1001],
           [ 0.4245,  0.9241, -0.1006,  0.4842]],

          [[-0.9200, -0.2118,  0.4307, -1.9475],
           [ 0.1660,  1.2924, -0.6666, -0.9043],
           [-0.5125,  0.4168, -1.0174,  1.0805],
           [-0.9595, -0.2691,  1.4210, -1.9117]]]],



        [[[[-0.0616, -2.3956, -0.3305,  1.0422],
           [ 1.2706,  0.0987, -0.0407, -1.1116],
           [-0.1327,  0.0740, -0.7086, -0.7971],
           [ 1.0783,  0.4304,  1.3829,  0.0207]],

    

In [5]:
import torch
a = torch.tril(torch.ones(1024,1024).view(1, 1, 1024, 1024))

In [6]:
print(a)

tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
          [1., 1., 0.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 0., 0.],
          [1., 1., 1.,  ..., 1., 1., 0.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]])


In [12]:
a.size()

torch.Size([1, 1, 1024, 1024])