In [59]:
import math
import logging
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
from torch.nn import init
from tensorly.tenalg import multi_mode_dot
from transformers.generation_utils import GenerationMixin
from transformers.file_utils import PushToHubMixin
from transformers.modeling_utils import (ModuleUtilsMixin,
                                         apply_chunking_to_forward, 
                                         find_pruneable_heads_and_indices,
                                         prune_linear_layer)
from transformers.models.bart.modeling_bart import BartEncoderLayer, BartDecoderLayer
from transformers.activations import gelu_new
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions, 
    CausalLMOutputWithCrossAttentions
)


import random
import math
import logging
import argparse
from tqdm import tqdm

import yaml
import hydra
from omegaconf import OmegaConf

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_metric, load_from_disk
from transformers import AutoConfig, AutoTokenizer
from transformers import set_seed, get_cosine_schedule_with_warmup, AdamW

from model import GrafomerModel
from utils import preprocess_function_with_setting, load_data, postprocess_text, CustomDataCollator

In [60]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [113]:

class TeacherWeightGroup:
    """Holds a teacher network and slices its parameters into per-student-layer groups.

    The teacher is stored on the class itself (``teacher_model``), so it is shared by
    every caller and has to be registered with ``set_network`` before a group is
    requested.
    """

    # Teacher network the groups are read from; stays None until set_network() is called.
    teacher_model: nn.Module = None

    @classmethod
    def set_network(cls, teacher_model: nn.Module):
        """Register ``teacher_model`` as the shared teacher network."""
        TeacherWeightGroup.teacher_model = teacher_model

    @classmethod
    def _generate_group(
        cls,
        weight_class_name: str,
        suffix: str,
        current_layer_index: int,
        num_student_layers: int
    ) -> torch.Tensor:
        """Stack all teacher parameters of one class and return the slice for one student layer.

        Args:
            weight_class_name: ``"<part>.<parameter class>"`` where ``<part>`` is
                ``"encoder"`` or ``"decoder"`` (e.g. ``"encoder.intermediate.dense"``).
            suffix: ``".weight"`` or ``".bias"``, appended to the parameter class.
            current_layer_index: 1-based index of the student layer.
            num_student_layers: number of student layers the teacher layers are split over.

        Returns:
            The matching teacher parameters stacked along a new last axis, restricted to
            the ``current_layer_index``-th run of
            ``teacher_layers // num_student_layers`` adjacent entries.

        Raises:
            RuntimeError: no teacher has been registered, or nothing matches the name.
            ValueError: the name has no valid ``encoder.``/``decoder.`` prefix, or the
                requested slice is empty.
        """
        teacher = TeacherWeightGroup.teacher_model
        if teacher is None:
            raise RuntimeError(
                "TeacherWeightGroup.set_network() must be called before generating groups"
            )

        part, separator, class_name = weight_class_name.partition(".")
        if not separator or part not in ("encoder", "decoder"):
            raise ValueError(
                "weight_class_name must look like 'encoder.<name>' or 'decoder.<name>', "
                f"got {weight_class_name!r}"
            )

        target = class_name + suffix
        # Attention and non-attention parameters can share a trailing name (e.g.
        # 'output.dense.weight'), so a match must agree with the request on whether
        # 'attention' occurs in the name.
        wants_attention = "attention" in target
        matches = [
            parameter
            for parameter_name, parameter in getattr(teacher, part).named_parameters()
            if target in parameter_name and ("attention" in parameter_name) == wants_attention
        ]
        if not matches:
            raise RuntimeError(f"no teacher parameter matching {target!r} found under '{part}'")

        stacked = torch.stack(matches, dim=-1)
        teacher_network_layers = stacked.size(-1)

        # Integer division: leftover teacher layers (when not evenly divisible) are unused.
        layers_per_group = teacher_network_layers // num_student_layers
        start = (current_layer_index - 1) * layers_per_group
        end = current_layer_index * layers_per_group

        group = stacked[..., start:end]
        if group.size(-1) == 0:
            raise ValueError(
                f"layer group {current_layer_index} of {num_student_layers} is empty "
                f"for {teacher_network_layers} teacher layers"
            )
        return group

    @classmethod
    def generate_weight_group(
        cls,
        weight_class_name: str,
        current_layer_index: int,
        num_student_layers: int
    ):
        """Return the ``.weight`` tensors of one layer group, stacked on the last axis.

        See ``_generate_group`` for the argument and error contract.
        """
        return cls._generate_group(
            weight_class_name, ".weight", current_layer_index, num_student_layers
        )

    @classmethod
    def generate_bias_group(
        cls,
        weight_class_name: str,
        current_layer_index: int,
        num_student_layers: int
    ):
        """Return the ``.bias`` tensors of one layer group, stacked on the last axis.

        See ``_generate_group`` for the argument and error contract.
        """
        return cls._generate_group(
            weight_class_name, ".bias", current_layer_index, num_student_layers
        )


# new
class WeightGenerator(nn.Module):
    """Generates a student weight matrix from a group of adjacent teacher weights.

    The teacher group (unpacked as ``(out, in, num_adjacent_layers)``) is mixed along
    its last axis by the learnable vector ``W_l``, squashed with ``tanh`` and then
    scaled element-wise by ``W`` and shifted by ``B``.

    Args:
        weight_class_name: teacher parameter class, e.g. ``"encoder.intermediate.dense"``.
        current_layer_index: index of the student layer, passed to ``TeacherWeightGroup``.
        num_student_layers: number of student layers, passed to ``TeacherWeightGroup``.
        student_weight_in: second dimension of the learnable ``W`` and ``B``.
        student_weight_out: first dimension of the learnable ``W`` and ``B``.
    """

    def __init__(
        self,
        weight_class_name: str,
        current_layer_index: int,
        num_student_layers: int,
        student_weight_in: int,
        student_weight_out: int
    ):
        super().__init__()
        teacher_group = TeacherWeightGroup.generate_weight_group(
            weight_class_name, current_layer_index, num_student_layers
        )
        # Unpacking three values also asserts the group is 3-D.
        _, _, num_adjacent_layers = teacher_group.size()

        # Detach so backward() through this module does not accumulate gradients into
        # the teacher's parameters, and register as a buffer so `.to(device)` moves the
        # teacher slice together with the learnable parameters. Non-persistent keeps the
        # state_dict keys unchanged.
        self.register_buffer("subset", teacher_group.detach(), persistent=False)

        self.W_l = nn.Parameter(torch.empty(num_adjacent_layers, 1))
        self.W = nn.Parameter(torch.ones(student_weight_out, student_weight_in))
        self.B = nn.Parameter(torch.zeros(student_weight_out, student_weight_in))

        self.tanh = nn.Tanh()
        self.init_weights_()

    def init_weights_(self):
        """Initialise the layer-mixing vector ``W_l`` with Xavier-uniform values."""
        init.xavier_uniform_(self.W_l)

    def forward(self) -> torch.Tensor:
        """Return ``tanh(subset @ W_l) * W + B`` with the trailing unit axis removed."""
        student_param = self.subset.matmul(self.W_l)
        return self.tanh(student_param.squeeze(-1)) * self.W + self.B


# New
class BiasGenerator(nn.Module):
    """Generates a student bias vector from a group of adjacent teacher biases.

    The teacher group (unpacked as ``(out_features, num_adjacent_layers)``) is mixed
    along its last axis by the learnable vector ``W_l``, squashed with ``tanh`` and
    then scaled element-wise by ``W`` and shifted by ``B``.

    Args:
        weight_class_name: teacher parameter class, e.g. ``"encoder.intermediate.dense"``.
        current_layer_index: index of the student layer, passed to ``TeacherWeightGroup``.
        num_student_layers: number of student layers, passed to ``TeacherWeightGroup``.
        student_out_features: length of the learnable ``W`` and ``B``.
    """

    def __init__(
        self,
        weight_class_name: str,
        current_layer_index: int,
        num_student_layers: int,
        student_out_features: int,
    ):
        super().__init__()
        teacher_group = TeacherWeightGroup.generate_bias_group(
            weight_class_name, current_layer_index, num_student_layers
        )
        # Unpacking two values also asserts the group is 2-D.
        _, num_adjacent_layers = teacher_group.shape

        # Detach so backward() through this module does not accumulate gradients into
        # the teacher's parameters, and register as a buffer so `.to(device)` moves the
        # teacher slice together with the learnable parameters. Non-persistent keeps the
        # state_dict keys unchanged.
        self.register_buffer("subset", teacher_group.detach(), persistent=False)

        self.W_l = nn.Parameter(torch.empty(num_adjacent_layers, 1))
        self.W = nn.Parameter(torch.ones(student_out_features))
        self.B = nn.Parameter(torch.zeros(student_out_features))

        self.tanh = nn.Tanh()
        self.init_weights_()

    def init_weights_(self):
        """Initialise the layer-mixing vector ``W_l`` with Xavier-uniform values."""
        init.xavier_uniform_(self.W_l)

    def forward(self) -> torch.Tensor:
        """Return ``tanh(subset @ W_l) * W + B`` with the trailing unit axis removed."""
        student_param = self.subset.matmul(self.W_l)
        return self.tanh(student_param.squeeze(-1)) * self.W + self.B


class StudentLinear(nn.Module):
    """Linear layer whose weight and bias are produced by generators on every call.

    The weight comes from a ``WeightGenerator`` and the bias from a ``BiasGenerator``;
    both are built for the same teacher parameter class and student layer.
    """

    def __init__(
        self,
        weight_class_name: str,
        current_layer_index: int,
        num_student_layers: int,
        in_features: int,
        out_features: int,
    ):
        super().__init__()

        # Arguments both generators share: which teacher parameters to read.
        teacher_selector = dict(
            weight_class_name=weight_class_name,
            current_layer_index=current_layer_index,
            num_student_layers=num_student_layers,
        )

        self.weight_generator = WeightGenerator(
            student_weight_in=in_features,
            student_weight_out=out_features,
            **teacher_selector,
        )
        self.bias_generator = BiasGenerator(
            student_out_features=out_features,
            **teacher_selector,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` to ``inputs`` with a freshly generated weight and bias."""
        return F.linear(inputs, self.weight_generator(), self.bias_generator())


class StudentMLP(nn.Module):
    """Two-layer feed-forward block built from ``StudentLinear`` layers.

    Reads ``hidden_size``, ``intermediate_size`` and ``resid_pdrop`` from ``config``
    (indexed like a mapping). The projections use the teacher parameter classes
    ``decoder.mlp.c_fc`` and ``decoder.mlp.c_proj``.
    """

    def __init__(
        self,
        current_layer_index,
        num_student_layers,
        config
    ):
        super().__init__()
        width, inner_width = config["hidden_size"], config["intermediate_size"]
        layer_position = (current_layer_index, num_student_layers)

        # hidden -> intermediate, then intermediate -> hidden
        self.c_fc = StudentLinear("decoder.mlp.c_fc", *layer_position, width, inner_width)
        self.c_proj = StudentLinear("decoder.mlp.c_proj", *layer_position, inner_width, width)
        self.act = gelu_new
        self.dropout = nn.Dropout(config["resid_pdrop"])

    def forward(self, hidden_states):
        """Return ``dropout(c_proj(act(c_fc(hidden_states))))``."""
        projected = self.c_proj(self.act(self.c_fc(hidden_states)))
        return self.dropout(projected)

In [115]:
from transformers import AutoModel

# Load the pretrained 'bert-base-multilingual-cased' checkpoint; a later cell
# registers this object as the teacher via TeacherWeightGroup.set_network(model).
model = AutoModel.from_pretrained('bert-base-multilingual-cased')

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [116]:
# Put the teacher on the first GPU when CUDA is available and leave it on the CPU
# otherwise (a hard-coded 'cuda:0' raises on CPU-only hosts), then register it so
# the weight/bias generators can read its parameters.
model.to('cuda:0' if torch.cuda.is_available() else 'cpu')
TeacherWeightGroup.set_network(model)

In [117]:
# Ask PyTorch to release cached GPU memory that is currently unused.
torch.cuda.empty_cache()

In [118]:
# Build two teacher-derived linear layers for layer group 2 of 2:
# 768 -> 3072 from 'encoder.intermediate.dense' and 3072 -> 768 from 'encoder.output.dense'.
# Requires TeacherWeightGroup.set_network(...) to have been run first.
intermediate = StudentLinear('encoder.intermediate.dense', 2, 2, 768, 3072)
output = StudentLinear('encoder.output.dense', 2, 2, 3072, 768)

In [119]:
# Display the module tree of the generated layer as the cell output.
intermediate

StudentLinear(
  (weight_generator): WeightGenerator(
    (tanh): Tanh()
  )
  (bias_generator): BiasGenerator(
    (tanh): Tanh()
  )
)

In [120]:
# Ask PyTorch to release cached GPU memory that is currently unused.
torch.cuda.empty_cache()

In [66]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

# FashionMNIST training split, downloaded into ./data if missing; images are
# converted to tensors by ToTensor().
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# FashionMNIST test split, same storage location and transform.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Batches of 64 samples; no shuffling is requested for either loader.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)


class NeuralNetwork(nn.Module):
    """FashionMNIST classifier whose two middle layers are teacher-derived.

    Layout: flatten -> Linear(784, 768) -> ReLU -> StudentLinear(768, 3072)
    -> StudentLinear(3072, 768) -> ReLU -> Linear(768, 10).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Built in execution order so the Sequential indices stay 0..5.
        stack = [
            nn.Linear(28*28, 768),
            nn.ReLU(),
            StudentLinear('encoder.intermediate.dense', 2, 2, 768, 3072),
            StudentLinear('encoder.output.dense', 2, 2, 3072, 768),
            nn.ReLU(),
            nn.Linear(768, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` and return the class logits."""
        return self.linear_relu_stack(self.flatten(x))

# Rebinds the notebook-level name `model` to the classifier defined above.
model = NeuralNetwork()

In [71]:
# Hyperparameters. Note: the DataLoaders are created with a literal batch_size=64,
# so `batch_size` here is informational, and `epochs` is reassigned in the cell
# that runs the training loop.
learning_rate = 1e-3
batch_size = 64
epochs = 5

# Cross-entropy on the logits, plain SGD over every parameter of `model`.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# model.to(device)

def train_loop(dataloader, model, loss_fn, optimizer, epochs):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` in training mode, performs one optimiser step per batch and
    prints the batch loss every 100 batches. ``epochs`` is accepted but not used
    inside the function.
    """
    model.train()
    size = len(dataloader.dataset)

    for step, (inputs, targets) in enumerate(dataloader):
        # Forward pass: prediction and loss
        batch_loss = loss_fn(model(inputs), targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward(retain_graph=True)
        optimizer.step()

        # Only every 100th batch is reported
        if step % 100 != 0:
            continue
        loss = batch_loss.item()
        current = step * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [72]:
# Train for 10 epochs, evaluating on the test split after each one.
# `epochs` is also forwarded to train_loop, which does not use it.
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer, epochs)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.201364  [    0/60000]
loss: 1.976470  [ 6400/60000]
loss: 1.656949  [12800/60000]
loss: 1.551333  [19200/60000]
loss: 1.214490  [25600/60000]
loss: 1.107938  [32000/60000]
loss: 1.051304  [38400/60000]
loss: 0.924348  [44800/60000]
loss: 0.933896  [51200/60000]
loss: 0.852989  [57600/60000]
Test Error: 
 Accuracy: 73.7%, Avg loss: 0.819374 

Epoch 2
-------------------------------
loss: 0.796870  [    0/60000]
loss: 0.882160  [ 6400/60000]
loss: 0.600098  [12800/60000]
loss: 0.860670  [19200/60000]
loss: 0.710536  [25600/60000]
loss: 0.669946  [32000/60000]
loss: 0.732355  [38400/60000]
loss: 0.701512  [44800/60000]
loss: 0.718694  [51200/60000]
loss: 0.664167  [57600/60000]
Test Error: 
 Accuracy: 78.1%, Avg loss: 0.634836 

Epoch 3
-------------------------------
loss: 0.543836  [    0/60000]
loss: 0.688140  [ 6400/60000]
loss: 0.436789  [12800/60000]
loss: 0.730121  [19200/60000]
loss: 0.611744  [25600/60000]
loss: 0.569436  [32000/600

KeyboardInterrupt: 

In [15]:
# Print the name of every parameter registered on `model`.
for name, _ in model.named_parameters():
    print(name)

linear_relu_stack.0.weight
linear_relu_stack.0.bias
linear_relu_stack.2.weight_generator.W_l
linear_relu_stack.2.weight_generator.W
linear_relu_stack.2.weight_generator.B
linear_relu_stack.2.bias_generator.W_l
linear_relu_stack.2.bias_generator.W
linear_relu_stack.2.bias_generator.B
linear_relu_stack.3.weight_generator.W_l
linear_relu_stack.3.weight_generator.W
linear_relu_stack.3.weight_generator.B
linear_relu_stack.3.bias_generator.W_l
linear_relu_stack.3.bias_generator.W
linear_relu_stack.3.bias_generator.B
linear_relu_stack.5.weight
linear_relu_stack.5.bias


In [85]:
## TODO: move to GPU
# FashionMNIST training split, downloaded into ./data if missing; images are
# converted to tensors by ToTensor().
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# FashionMNIST test split, same storage location and transform.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Batches of 64 samples; no shuffling is requested for either loader.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    """Baseline MLP classifier: flatten, two 512-unit ReLU layers, 10 output logits."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Built in execution order so the Sequential indices stay 0..4.
        stack = [
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` and return the class logits."""
        return self.linear_relu_stack(self.flatten(x))

# Use the first GPU when CUDA is available, otherwise the CPU; the training and
# evaluation loops below read this global.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = NeuralNetwork()
model.to(device)

# `learning_rate` is not defined in this cell; it comes from the earlier
# hyperparameter cell and must already exist in the kernel.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to the notebook-level ``device`` before the forward pass.
    The batch loss is printed every 100 batches.
    """
    size = len(dataloader.dataset)
    for batch, (features, labels) in enumerate(dataloader):
        # Move the mini-batch to the selected device
        features = features.to(device)
        labels = labels.to(device)

        # Forward pass
        batch_loss = loss_fn(model(features), labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Progress report
        if batch % 100 == 0:
            loss = batch_loss.item()
            current = batch * len(features)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean batch loss.

    Batches are moved to the notebook-level ``device``. The model is switched to
    evaluation mode for the duration of the pass, so mode-dependent layers such as
    dropout or batch norm behave deterministically. Its previous training/eval
    mode is restored afterwards, even on error.

    Args:
        dataloader: iterable of ``(X, y)`` batches exposing ``.dataset`` and ``len()``.
        model: ``nn.Module`` mapping ``X`` to class logits.
        loss_fn: callable returning a scalar loss tensor for ``(pred, y)``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        model.train(was_training)

    # Loss is averaged over batches, accuracy over individual samples.
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [87]:
# Train for 10 epochs, evaluating on the test split after each one.
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.318992  [    0/60000]
loss: 2.298573  [ 6400/60000]
loss: 2.278566  [12800/60000]
loss: 2.257627  [19200/60000]
loss: 2.259730  [25600/60000]
loss: 2.223861  [32000/60000]
loss: 2.230925  [38400/60000]
loss: 2.206994  [44800/60000]
loss: 2.201635  [51200/60000]
loss: 2.157983  [57600/60000]
Test Error: 
 Accuracy: 36.5%, Avg loss: 2.157239 

Epoch 2
-------------------------------
loss: 2.176568  [    0/60000]
loss: 2.162852  [ 6400/60000]
loss: 2.111946  [12800/60000]
loss: 2.119023  [19200/60000]
loss: 2.078768  [25600/60000]
loss: 2.014045  [32000/60000]
loss: 2.039536  [38400/60000]
loss: 1.971239  [44800/60000]
loss: 1.965335  [51200/60000]
loss: 1.898296  [57600/60000]
Test Error: 
 Accuracy: 55.1%, Avg loss: 1.895292 

Epoch 3
-------------------------------
loss: 1.925843  [    0/60000]
loss: 1.901611  [ 6400/60000]
loss: 1.794170  [12800/60000]
loss: 1.831080  [19200/60000]
loss: 1.724633  [25600/60000]
loss: 1.666746  [32000/600

In [133]:
class NeuralNetwork(nn.Module):
    """FashionMNIST classifier whose two middle layers are teacher-derived.

    Layout: flatten -> Linear(784, 768) -> ReLU -> StudentLinear(768, 3072)
    -> StudentLinear(3072, 768) -> ReLU -> Linear(768, 10).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Input projection, the two generated layers, then the classifier head;
        # constructed in this order so the Sequential indices stay 0..5.
        input_projection = nn.Linear(28*28, 768)
        generated_up = StudentLinear('encoder.intermediate.dense', 2, 2, 768, 3072)
        generated_down = StudentLinear('encoder.output.dense', 2, 2, 3072, 768)
        classifier_head = nn.Linear(768, 10)

        self.linear_relu_stack = nn.Sequential(
            input_projection,
            nn.ReLU(),
            generated_up,
            generated_down,
            nn.ReLU(),
            classifier_head,
        )

    def forward(self, x):
        """Flatten ``x`` and return the class logits."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)

# Use the first GPU when CUDA is available, otherwise the CPU; the training and
# evaluation loops below read this global.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = NeuralNetwork()
model.to(device)

# `learning_rate` is not defined in this cell; it comes from the earlier
# hyperparameter cell and must already exist in the kernel.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to the notebook-level ``device`` before the forward pass.
    ``backward`` is called with ``retain_graph=True``. The batch loss is printed
    every 100 batches.
    """
    size = len(dataloader.dataset)
    for batch, (inputs, targets) in enumerate(dataloader):
        # Forward pass on the selected device
        inputs, targets = inputs.to(device), targets.to(device)
        batch_loss = loss_fn(model(inputs), targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward(retain_graph=True)
        optimizer.step()

        # Only every 100th batch is reported
        if batch % 100 != 0:
            continue
        loss, current = batch_loss.item(), batch * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean batch loss.

    Batches are moved to the notebook-level ``device``. The model is switched to
    evaluation mode for the duration of the pass, so mode-dependent layers such as
    dropout or batch norm behave deterministically. Its previous training/eval
    mode is restored afterwards, even on error.

    Args:
        dataloader: iterable of ``(X, y)`` batches exposing ``.dataset`` and ``len()``.
        model: ``nn.Module`` mapping ``X`` to class logits.
        loss_fn: callable returning a scalar loss tensor for ``(pred, y)``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        model.train(was_training)

    # Loss is averaged over batches, accuracy over individual samples.
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [134]:
# Train for 10 epochs, evaluating on the test split after each one.
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.292444  [    0/60000]
loss: 1.659353  [ 6400/60000]
loss: 1.179987  [12800/60000]
loss: 1.211576  [19200/60000]
loss: 0.877621  [25600/60000]
loss: 0.833095  [32000/60000]
loss: 0.842104  [38400/60000]
loss: 0.770160  [44800/60000]
loss: 0.748130  [51200/60000]
loss: 0.741619  [57600/60000]
Test Error: 
 Accuracy: 76.4%, Avg loss: 0.695336 

Epoch 2
-------------------------------
loss: 0.632415  [    0/60000]
loss: 0.710258  [ 6400/60000]
loss: 0.479420  [12800/60000]
loss: 0.786787  [19200/60000]
loss: 0.627807  [25600/60000]
loss: 0.578678  [32000/60000]
loss: 0.658544  [38400/60000]
loss: 0.665536  [44800/60000]
loss: 0.635270  [51200/60000]
loss: 0.581837  [57600/60000]
Test Error: 
 Accuracy: 80.3%, Avg loss: 0.574849 

Epoch 3
-------------------------------
loss: 0.464922  [    0/60000]
loss: 0.567500  [ 6400/60000]
loss: 0.376697  [12800/60000]
loss: 0.675880  [19200/60000]
loss: 0.539657  [25600/60000]
loss: 0.509184  [32000/600

In [135]:
# Ask PyTorch to release cached GPU memory that is currently unused.
torch.cuda.empty_cache()