<a href="https://colab.research.google.com/github/gauss5930/Natural-Language-Processing/blob/main/ELMo/char_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#char_cnn
import torch
import torch.nn as nn
from typing import List

class CharEmbedding(nn.Module):
  """Character-based word embedding: CharCNN -> highway layer -> linear projection.

  Args:
    vocab_size: size of the character vocabulary (passed to ``CharCNN``).
    emb_dim: character embedding size (passed to ``CharCNN``).
    prj_dim: output size of the final linear projection.
    kernel_sizes: iterable of ``(kernel_size, num_features)`` pairs, the same
      convention ``CharCNN`` uses to build its convolutions.
    char_len: characters per word (passed to ``CharCNN``).
    device: stored as ``self.device`` and passed to ``CharCNN``.
  """

  def __init__(self, vocab_size, emb_dim, prj_dim, kernel_sizes, char_len, device):
    super().__init__()
    self.device = device
    # Width of the CharCNN output: the total number of conv output channels,
    # i.e. the sum of the second element of every (kernel_size, num_features) pair.
    self.kernel_dim = sum(num_features for _kernel_size, num_features in kernel_sizes)
    self.charcnn = CharCNN(vocab_size, emb_dim, self.kernel_dim, kernel_sizes, char_len, device)
    self.highway_net = HighwayNetwork(self.kernel_dim)
    # Negative transform-gate bias so the highway layer initially favours carrying its input through.
    self.highway_net._init_bias()
    self.projection_layer = nn.Linear(self.kernel_dim, prj_dim)

  def forward(self, x):
    """Embed every word of every sentence from its characters.

    Args:
      x: character ids of the sentences, shape [Batch, Seq_len, Char_len].

    Returns:
      Tensor of shape [Batch, Seq_len, prj_dim].
    """
    batch_size, seq_len, char_len = x.size()
    # Fold the sequence axis into the batch so all words go through the CNN and
    # highway layer in one call instead of one call per position.
    words = x.reshape(batch_size * seq_len, char_len)
    char_emb = self.charcnn(words)                          # [Batch*Seq_len, Kernel_dim]
    highway_emb = self.highway_net(char_emb).squeeze(1)     # highway returns [N, 1, Kernel_dim]
    y = highway_emb.reshape(batch_size, seq_len, self.kernel_dim)

    emb = self.projection_layer(y)
    return emb

class CharCNN(nn.Module):
  """Character-level CNN that maps each word's character ids to a fixed-size vector.

  One ``nn.Conv1d`` is built per ``(kernel_size, num_features)`` pair in
  ``kernel_sizes``. Each convolution is max-pooled over the character axis, and
  the pooled features of all kernels are written side by side into the output.

  Args:
    vocab_size: number of rows in the character embedding table.
    char_emb_dim: size of each character embedding (Conv1d input channels).
    word_emb_dim: width of the output vector; expected to equal the sum of
      ``num_features`` over ``kernel_sizes``.
    kernel_sizes: iterable of ``(kernel_size, num_features)`` pairs.
    char_len: stored as ``self.char_len``; not used by ``forward``.
    device: stored as ``self.device``.
  """

  def __init__(self, vocab_size, char_emb_dim, word_emb_dim, kernel_sizes, char_len, device):
    super(CharCNN, self).__init__()
    self.device = device
    self.char_len = char_len
    self.word_emb_dim = word_emb_dim
    self.kernel_sizes = kernel_sizes

    self.embedding = nn.Embedding(vocab_size, char_emb_dim)
    # Each pair is (kernel_size, num_features): kernel width and number of output channels.
    self.kernels = nn.ModuleList([nn.Conv1d(in_channels = char_emb_dim, out_channels = num_features,
                                            kernel_size = kernel_size) for kernel_size, num_features in kernel_sizes])

  def forward(self, word):
    """Encode a batch of words.

    Args:
      word: character ids, shape [Batch, Char_len].

    Returns:
      Tensor y of shape [Batch, word_emb_dim].
    """
    batch_size = word.size(0)
    # The embedding does not depend on the kernel, so compute it once.
    # [Batch, Char_len] -> [Batch, Char_len, Emb_dim] -> [Batch, Emb_dim, Char_len] (Conv1d layout).
    emb = self.embedding(word).permute(0, 2, 1)
    # Allocate on the same device/dtype as the embeddings.
    y = emb.new_zeros(batch_size, self.word_emb_dim)

    offset = 0   # first column in y for the current kernel's features
    for kernel in self.kernels:
      pooled = torch.max(kernel(emb), dim = 2)[0]   # [Batch, num_features]
      width = pooled.size(1)
      y[:, offset:offset + width] = pooled
      offset += width

    return y

class HighwayNetwork(nn.Module):
  """Single highway layer computing T(x) * relu(H(x)) + (1 - T(x)) * x.

  Args:
    kernel_sizes: feature width of both input and output (an int, despite the name).
  """

  def __init__(self, kernel_sizes):
    super().__init__()
    width = kernel_sizes
    # Creation order (H before T) is kept so parameter init consumes the RNG identically.
    self.h_gate = nn.Linear(width, width)
    self.t_gate = nn.Sequential(nn.Linear(width, width), nn.Sigmoid())
    self.relu = torch.nn.ReLU()

  def forward(self, x):
    """Apply the highway layer.

    Args:
      x: input of shape [Batch, Kernel_dim].

    Returns:
      Tensor of shape [Batch, 1, Kernel_dim] (a singleton axis is inserted at dim 1).
    """
    x = x.unsqueeze(1)
    gate = self.t_gate(x)
    transformed = self.relu(self.h_gate(x))
    return gate * transformed + (1 - gate) * x

  def _init_bias(self):
    """Set every transform-gate bias to -2 so the gate starts mostly closed (carry dominates)."""
    gate_linear = self.t_gate[0]
    nn.init.constant_(gate_linear.bias, -2)