In [None]:
from timeit import default_timer as timer
import torch
import utilities
import gc

from torchtext.experimental.datasets import WikiText103

from torch.utils.data import DataLoader

from Container import NNContainer

from customLayers import BertEmbedding, BertTransformerEncoderLayer

from torch.nn import Linear

import torch
from typing import Tuple, Optional


class MultiheadAttentionContainer(torch.nn.Module):
    """Multi-head attention assembled from pluggable sub-modules.

    ``in_proj_container`` projects query/key/value, the projected feature dim
    is split into ``nhead`` heads (folded into the batch dim), the per-head
    tensors go through ``attention_layer``, and the merged result is passed
    through ``out_proj``.

    Args:
        nhead: number of attention heads.
        in_proj_container: callable ``(query, key, value) -> (q, k, v)``.
        attention_layer: callable ``(q, k, v, attn_mask=..., bias_k=..., bias_v=...)``
            returning ``(attn_output, attn_output_weights)``.
        out_proj: callable applied to the merged attention output.
        batch_first: if True, inputs/outputs are ``(batch, seq, feature)``
            instead of ``(seq, batch, feature)``.
    """

    def __init__(self, nhead, in_proj_container, attention_layer, out_proj, batch_first=False):
        super(MultiheadAttentionContainer, self).__init__()
        self.nhead = nhead
        self.in_proj_container = in_proj_container
        self.attention_layer = attention_layer
        self.out_proj = out_proj
        self.batch_first = batch_first

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                bias_k: Optional[torch.Tensor] = None,
                bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run multi-head attention.

        Args:
            query: ``(tgt_len, bsz, embed_dim)``, or ``(bsz, tgt_len, embed_dim)``
                when ``batch_first``.
            key, value: ``(src_len, bsz, feature)``, or ``(bsz, src_len, feature)``
                when ``batch_first``.
            attn_mask, bias_k, bias_v: forwarded unchanged to ``attention_layer``.

        Returns:
            ``(attn_output, attn_output_weights)``. ``attn_output`` is ``out_proj``
            applied to the merged heads, in the same seq-/batch-first layout as
            the input. ``attn_output_weights`` is whatever ``attention_layer``
            returned.

        Raises:
            AssertionError: if a projected feature dim is not divisible by ``nhead``.
        """
        if self.batch_first:
            query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)

        tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)

        q, k, v = self.in_proj_container(query, key, value)

        # Split each projected feature dim into heads and fold the heads into the
        # batch dim: (seq, bsz, nhead * head_dim) -> (seq, bsz * nhead, head_dim).
        assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads"
        head_dim = q.size(-1) // self.nhead
        q = q.reshape(tgt_len, bsz * self.nhead, head_dim)

        assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads"
        head_dim = k.size(-1) // self.nhead
        k = k.reshape(src_len, bsz * self.nhead, head_dim)

        assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads"
        head_dim = v.size(-1) // self.nhead
        v = v.reshape(src_len, bsz * self.nhead, head_dim)

        attn_output, attn_output_weights = self.attention_layer(q, k, v, attn_mask=attn_mask,
                                                                bias_k=bias_k, bias_v=bias_v)
        # Merge heads back. This expects attention_layer to return
        # (tgt_len, bsz * nhead, head_dim) with nhead * head_dim == embed_dim.
        attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        if self.batch_first:
            attn_output = attn_output.transpose(-3, -2)

        return attn_output, attn_output_weights


class ScaledDotProduct(torch.nn.Module):
    """Scaled dot-product attention over tensors whose heads are folded into the batch dim.

    Args:
        dropout: dropout probability applied to the attention weights
            (active only in training mode).
        batch_first: if True, inputs/outputs are ``(batch_heads, seq, head_dim)``
            instead of ``(seq, batch_heads, head_dim)``.
    """

    def __init__(self, dropout=0.0, batch_first=False):
        super().__init__()
        self.dropout = dropout
        self.batch_first = batch_first

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                bias_k: Optional[torch.Tensor] = None,
                bias_v: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend ``query`` over ``key``/``value``.

        Args:
            query: ``(tgt_len, batch_heads, head_dim)``, or ``(batch_heads, tgt_len, head_dim)``
                when ``batch_first``.
            key, value: same layout as ``query`` with ``src_len`` positions; their
                shapes must be identical.
            attn_mask: optional 3-D bool tensor ``(1 or batch_heads, tgt_len, src_len)``;
                ``True`` entries are filled with -1e8 before the softmax.
            bias_k, bias_v: optional ``(1, batch_heads, head_dim)`` tensors (seq-first
                layout) appended as one extra key/value position; used only when both
                are given.

        Returns:
            ``(attn_output, attn_output_weights)``. The output uses the same layout as
            ``query``. The weights are ``(batch_heads, tgt_len, src_len)``, with one extra
            source position when biases were appended.

        Raises:
            AssertionError: on incompatible query/key/value/bias shapes.
            RuntimeError: if ``attn_mask`` is not a correctly sized 3-D bool tensor.
        """
        if self.batch_first:
            # Normalise to (seq, batch_heads, dim) for the shape checks below.
            query = query.transpose(-3, -2)
            key = key.transpose(-3, -2)
            value = value.transpose(-3, -2)

        if bias_k is not None and bias_v is not None:
            assert (key.size(-1) == bias_k.size(-1) and key.size(-2) == bias_k.size(-2)
                    and bias_k.size(-3) == 1), "Shape of bias_k is not supported"
            assert (value.size(-1) == bias_v.size(-1) and value.size(-2) == bias_v.size(-2)
                    and bias_v.size(-3) == 1), "Shape of bias_v is not supported"
            # Append the biases along dim 0 (the sequence dim for 3-D inputs).
            key = torch.cat([key, bias_k])
            value = torch.cat([value, bias_v])
            if attn_mask is not None:
                # One more key position now exists: add a zero-padded trailing mask column.
                attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))

        tgt_len = query.size(-3)
        head_dim = query.size(-1)
        assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal."
        assert key.size() == value.size(), "Shape of key, value must match"
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError('attn_mask must be a 3D tensor.')
            shape_ok = (attn_mask.size(-1) == src_len
                        and attn_mask.size(-2) == tgt_len
                        and attn_mask.size(-3) in (1, batch_heads))
            if not shape_ok:
                raise RuntimeError('The size of the attn_mask is not correct.')
            if attn_mask.dtype != torch.bool:
                raise RuntimeError('Only bool tensor is supported for attn_mask')

        # Move batch_heads to the front so matmul batches over it; pre-scale q by 1/sqrt(head_dim).
        q = query.transpose(-2, -3) * (float(head_dim) ** -0.5)
        k = key.transpose(-2, -3)
        v = value.transpose(-2, -3)

        scores = torch.matmul(q, k.transpose(-2, -1))
        if attn_mask is not None:
            scores.masked_fill_(attn_mask, -1e8)
        weights = torch.nn.functional.softmax(scores, dim=-1)
        weights = torch.nn.functional.dropout(weights, p=self.dropout, training=self.training)
        output = torch.matmul(weights, v)

        # output is (batch_heads, tgt_len, head_dim); restore seq-first unless batch_first.
        if not self.batch_first:
            output = output.transpose(-3, -2)
        return output, weights


class InProjContainer(torch.nn.Module):
    """Bundle the three input projections used by multi-head attention.

    Args:
        query_proj: callable applied to the query.
        key_proj: callable applied to the key.
        value_proj: callable applied to the value.
    """

    def __init__(self, query_proj, key_proj, value_proj):
        super().__init__()
        self.query_proj = query_proj
        self.key_proj = key_proj
        self.value_proj = value_proj

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(query_proj(query), key_proj(key), value_proj(value))``, evaluated in that order."""
        pairs = ((self.query_proj, query), (self.key_proj, key), (self.value_proj, value))
        return tuple(project(tensor) for project, tensor in pairs)


def generate_square_subsequent_mask(nbatch, sz):
    r"""Generate a batched causal ("subsequent") attention mask.

    Masked positions are filled with True: those are the entries where the key
    index is greater than the query index, i.e. future tokens. Unmasked
    positions (the token itself and earlier tokens) are filled with False. This
    is the convention ``ScaledDotProduct`` uses: it fills True entries with a
    large negative value before the softmax.

    Args:
        nbatch: number of copies stacked along the leading dim
            (e.g. batch size * number of heads).
        sz: side length of the square mask (sequence length).

    Returns:
        Bool tensor of shape ``(nbatch, sz, sz)``.
    """
    # Strict upper triangle (diagonal=1): query i must not attend to keys j > i.
    mask = torch.triu(torch.ones(sz, sz), diagonal=1) == 1
    return mask.repeat(nbatch, 1, 1)

In [None]:
# Start the wall-clock timer for this cell; it is read back as `end` after preprocessing.
start = timer()
print("\nPreparing to load dataset....")
# Vocabulary saved with torch.save. The path is relative to the notebook's working directory.
# Presumably a torchtext Vocab: it is used below via `.stoi[...]` and `len(vocab)`.
vocab = torch.load("../../vocab/torchtext_bert_vocab.pt")

def process_raw_data(whole_data, frac_ns):
    """Turn raw token sequences into next-sentence-prediction samples.

    Every item with more than one token is cut at a random position into two
    non-empty halves, giving ``[first_part, second_part, 1]`` (label 1: the
    second part really follows the first). Afterwards,
    ``int(len(samples) * frac_ns)`` samples have their second part replaced by
    the first part of a randomly chosen sample and their label set to 0. A
    sample may be paired with itself.

    Args:
        whole_data: indexable sequence supporting ``len()``. List items are
            converted with ``torch.tensor``. Other items are used as-is and must
            support ``len()`` and slicing.
        frac_ns: fraction of samples turned into negative (label 0) pairs.

    Returns:
        List of mutable ``[first_part, second_part, label]`` samples.
    """
    samples = []
    for position in range(len(whole_data)):
        if position % 10000 == 0:
            print("Processing data....{}".format(position), end='\r', flush=True)
        tokens = whole_data[position]
        if isinstance(tokens, list):
            tokens = torch.tensor(tokens)
        if len(tokens) <= 1:
            continue  # nothing to split into two sentences
        # Cut index drawn from [1, len) so both halves are non-empty.
        cut = torch.randint(1, len(tokens), size=(1, 1)).item()
        samples.append([tokens[:cut], tokens[cut:], 1])

    n_pairs = len(samples)
    # Two independent permutations: the k-th target takes its new second half from the k-th donor.
    targets = torch.randperm(n_pairs).tolist()
    donors = torch.randperm(n_pairs).tolist()
    n_negative = int(n_pairs * frac_ns)
    for target, donor in zip(targets[:n_negative], donors[:n_negative]):
        samples[target][1] = samples[donor][0]
        samples[target][2] = 0  # no longer the true continuation
    return samples


def collate_batch(batch, bptt, cls_id, sep_id, pad_id):
    """Collate next-sentence-prediction samples into fixed-length tensors.

    Each sample ``[first, second, label]`` becomes
    ``first + [sep] + second + [sep]``. That sequence is truncated or
    right-padded with ``pad_id`` to exactly ``bptt`` tokens, and then a
    ``cls_id`` token is prepended.

    Args:
        batch: iterable of samples whose items 0 and 1 are 1-D token tensors and
            item 2 is the integer label.
        bptt: sequence length before the [cls] token is added.
        cls_id, sep_id, pad_id: special token ids.

    Returns:
        seq_input: long tensor ``(batch, bptt + 1)``, a transposed
            (non-contiguous) view.
        tok_type: long tensor ``(bptt + 1, batch)``. Entries are 0 for the [cls]
            slot, the first segment and its [sep], and 1 for every later
            position, including padding.
        labels: long tensor ``(batch,)``.

    NOTE(review): seq_input is batch-first while tok_type is sequence-first;
    confirm this is the layout the embedding layer expects.
    """
    sep = torch.tensor([sep_id]).long()
    sequences, segment_ids, labels = [], [], []
    for sample in batch:
        first, second, label = sample[0], sample[1], sample[2]
        pair = torch.cat([first, sep, second, sep])
        length = pair.size(0)
        if length > bptt:
            pair = pair[:bptt]
        elif length < bptt:
            pair = torch.cat((pair, torch.tensor([pad_id] * (bptt - length))))
        sequences.append(pair)

        segment = torch.ones((pair.size(0)))
        # Positions up to and including the first [sep] belong to segment 0.
        segment[:min(len(first) + 1, bptt)] = 0.0
        segment_ids.append(segment)
        labels.append(label)

    n_samples = len(sequences)
    cls_row = torch.tensor([[cls_id] * n_samples]).long()
    zero_row = torch.tensor([[0] * n_samples]).long()
    # Stack to (batch, bptt), go sequence-first, prepend [cls], then flip back to batch-first.
    seq_input = torch.cat((cls_row, torch.stack(sequences).long().t())).transpose(0, 1)
    tok_type = torch.cat((zero_row, torch.stack(segment_ids).long().t()))
    return seq_input, tok_type, torch.tensor(labels).long().contiguous()



dataset = WikiText103(vocab=vocab, split='valid') # NOTE: 'valid' keeps this quick; use split='train' for a full run.
# Split every sequence into a sentence pair; int(len * 0.5) pairs become negative (label 0) examples.
dataset = process_raw_data(dataset, frac_ns=0.5)
cls_id = vocab.stoi['<cls>']
pad_id = vocab.stoi['<pad>']
sep_id = vocab.stoi['<sep>']
bptt = 128  # fixed pair length; collate_batch pads/truncates to this, then prepends [cls]
end = timer()

print("Dataset Loaded in {} seconds.".format(end-start))

In [None]:
# Embedding layer, presumably (vocab size, 768-dim embeddings) -- confirm against customLayers.BertEmbedding.
module_1 = BertEmbedding(len(vocab), 768).to("cuda:0")

# One encoder layer. The args are presumably (d_model=768, nhead=16, dim_feedforward=1024, dropout=0.5);
# verify against customLayers.BertTransformerEncoderLayer.
module_2 = BertTransformerEncoderLayer(768, 16, 1024, 0.5).to("cuda:0")

In [None]:


# Shuffled batches of 128 sentence pairs, each collated to bptt tokens plus a leading [cls].
dataloader = DataLoader(dataset, batch_size=128, shuffle=True,
                            collate_fn=lambda b: collate_batch(b, bptt, cls_id, sep_id, pad_id))

# Take one batch and drop the last returned element (the next-sentence labels), keeping
# (seq_input, tok_type). NOTE(review): collate_batch returns seq_input as (batch, seq) but
# tok_type as (seq, batch) -- confirm the consumer expects that.
batch = next(iter(dataloader))[0:-1]

batch = [ x.to("cuda:0") for x in batch ]

In [None]:
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
import torch
import gc

def get_free_space(idx=0):
    """Return the free memory NVML reports for GPU ``idx`` (``nvmlMemory_t.free``; NVML documents it in bytes).

    NVML initialisation is reference-counted, so every ``nvmlInit`` is paired
    with an ``nvmlShutdown``, even when the query raises.

    NOTE(review): NVML's index ordering can differ from CUDA's device ordering
    (e.g. "cuda:0") unless CUDA_DEVICE_ORDER=PCI_BUS_ID -- confirm on multi-GPU hosts.
    """
    # Local import: the cell's top-level pynvml import does not bring in nvmlShutdown.
    from pynvml import nvmlShutdown

    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(idx)
        return nvmlDeviceGetMemoryInfo(handle).free
    finally:
        nvmlShutdown()


# Probe GPU memory around a single no-grad Linear forward pass: print free memory,
# run the layer on a zero tensor, free everything, then print free memory again.
linear_layer = torch.nn.Linear(768, 768).to("cuda:0")

with torch.no_grad():
    print(get_free_space(0))
    # (128, 129, 768) presumably mirrors an embedded batch: 128 pairs x (bptt + 1) tokens x 768 features.
    a_detach = torch.zeros((128, 129, 768)).to("cuda:0")
    d = linear_layer(a_detach)
    del d
    del a_detach
    del linear_layer  # NOTE: removes the global; later cells cannot reuse linear_layer
    gc.collect()
    # Release blocks cached by PyTorch's allocator so NVML reports them as free.
    torch.cuda.empty_cache()
    print(get_free_space(0))
    