In [60]:
from collections import Counter
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset



print("start")  # emits a marker line on stdout when this cell is executed

class FlagsDataset(Dataset):
    """Multi-label dataset pairing tokenised file paths with flag targets.

    Every non-blank line of the input file is parsed as two bracketed,
    ``', '``-separated lists: first the flags, then the file paths. Single
    quotes and newlines are stripped from each entry, and each path is split
    on ``/`` into tokens.

    Attributes:
        vocab_size: Number of distinct path tokens plus one; index 0 is never
            assigned to a token, which leaves it free for padding.
        n_classes: Number of distinct flags, i.e. the width of a label vector.
    """

    def __init__(self, path: str) -> None:
        """Parse ``path`` and build the token and flag index mappings."""
        super().__init__()

        self._data, self._vocab, self._unique_flags = self._load_data(path)
        # Token ids start at 1 so that 0 stays unused by real tokens.
        self._word2idx = {word: idx + 1 for idx, word in enumerate(self._vocab)}
        self._flag2idx = {flag: idx for idx, flag in enumerate(self._unique_flags)}
        # Longest token sequence in the dataset (0 when the file has no samples).
        self._max_len = max((len(tokens) for tokens, _ in self._data), default=0)

        self.vocab_size = len(self._vocab) + 1
        self.n_classes = len(self._unique_flags)

    def _load_data(self, path: str) -> Tuple[List[Tuple[List[str], List[str]]], List[str], List[str]]:
        """Read the dataset file.

        Args:
            path: Location of the text file to parse.

        Returns:
            A ``(data, vocab, unique_flags)`` triple: ``data`` holds one
            ``(path_tokens, flags)`` pair per parsed line, while ``vocab`` and
            ``unique_flags`` list the distinct tokens / flags ordered from
            most to least frequent.

        Raises:
            ValueError: If a non-blank line lacks the expected brackets.
        """
        vocab, unique_flags = Counter(), Counter()
        data = []
        # The context manager closes the handle even if a line fails to parse.
        with open(path, 'r') as file:
            for line in file:
                if not line.strip():
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue

                flags = (line[line.index('[') + 1:line.index(']')]).split(', ')
                flags = [self._filter(flag) for flag in flags]
                # update() adds counts in place without rescanning the whole
                # counter on every line, unlike ``+=``.
                unique_flags.update(flags)

                file_paths = (line[line.index('] [') + 3:]).replace(']', '').split(', ')
                file_paths = [self._tokenize(self._filter(file_path)) for file_path in file_paths]
                # Flatten the per-path token lists into one sequence.
                file_paths = [item for sublist in file_paths for item in sublist]
                vocab.update(file_paths)

                data.append((file_paths, flags))

        # Stable sort: ties keep their first-seen order.
        vocab = sorted(vocab, key=vocab.get, reverse=True)
        unique_flags = sorted(unique_flags, key=unique_flags.get, reverse=True)

        return data, vocab, unique_flags

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` on ``/`` and drop the first segment."""
        return text.split('/')[1:]

    def _filter(self, text: str) -> str:
        """Remove newline characters and single quotes from ``text``."""
        text = text.replace('\n', '')
        text = text.replace("'", '')
        return text

    def __len__(self, ) -> int:
        """Return the number of parsed samples."""
        return len(self._data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return sample ``idx`` as ``(token_ids, multi_hot_labels)``.

        ``token_ids`` is an int64 vector of variable length (no padding is
        applied here); ``multi_hot_labels`` is a float vector of length
        ``n_classes`` with ones at the positions of the sample's flags.
        """
        paths, flags = self._data[idx]

        inputs = torch.tensor([self._word2idx[word] for word in paths], dtype=torch.int64)

        labels = torch.zeros(self.n_classes, dtype=torch.float)
        labels[[self._flag2idx[flag] for flag in flags]] = 1

        return inputs, labels

start


In [2]:
class FlagClassifier(nn.Module):
    """Embedding -> LSTM -> linear head producing one logit per flag."""

    def __init__(self, vocab_size, n_classes, emb_dim=64, hidden_dim=128) -> None:
        """Create the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            n_classes: Number of output logits.
            emb_dim: Width of each token embedding.
            hidden_dim: Size of the LSTM hidden state.
        """
        super().__init__()

        self.hidden_dim = hidden_dim

        # Token ids -> dense vectors.
        self.emb_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
        )
        # Sequence encoder; default layout puts the time axis first.
        self.lstm = nn.LSTM(input_size=emb_dim, hidden_size=hidden_dim)

        # Projects the encoder output onto the flag logits.
        self.linear = nn.Linear(in_features=hidden_dim, out_features=n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits computed from the final LSTM output step.

        NOTE(review): the last index along the first axis is used for every
        sequence, so for zero-padded batches the shorter sequences are read
        after their padding steps - confirm this is intended.
        """
        embedded = self.emb_layer(x)
        outputs, _state = self.lstm(embedded)
        last_step = outputs[-1]
        return self.linear(last_step)

In [3]:
def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge variable-length samples into one padded batch.

    Args:
        batch: ``(token_ids, labels)`` pairs; ``token_ids`` may differ in
            length between samples, ``labels`` must share one shape.

    Returns:
        ``(inputs, labels)`` where ``inputs`` is the zero-padded stack of the
        token sequences laid out as ``(longest_sequence, batch)`` and
        ``labels`` is the label vectors stacked along a new leading axis.
    """
    inputs = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    # pad_sequence keeps its default time-major layout; 0 is the padding id.
    padded_inputs = nn.utils.rnn.pad_sequence(inputs, padding_value=0)
    return padded_inputs, torch.stack(labels, dim=0)

# Load the samples and batch them with padding handled by collate_fn.
dataset = FlagsDataset("dataset.txt")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True)

device = torch.device("cpu")
# Only the vocabulary size and class count are derived from the data. The
# embedding and hidden widths use the model's defaults: the third positional
# slot of FlagClassifier is emb_dim, which is unrelated to sequence length.
model = FlagClassifier(dataset.vocab_size, dataset.n_classes).to(device)
# Multi-label targets: an independent sigmoid/BCE term per flag.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters())

In [4]:
epochs = 100

# Plain full-pass training: one optimiser step per batch, mean batch loss
# reported at the end of every epoch.
for epoch in range(epochs):
    epoch_loss = 0.
    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1} finished, train loss: {epoch_loss / len(dataloader)}")

100%|██████████| 43/43 [00:46<00:00,  1.08s/it]


Epoch 1 finished, train loss: 0.2637136592421421


100%|██████████| 43/43 [00:59<00:00,  1.38s/it]


Epoch 2 finished, train loss: 0.045643650775038916


100%|██████████| 43/43 [01:55<00:00,  2.68s/it]


Epoch 3 finished, train loss: 0.027999386234685432


100%|██████████| 43/43 [00:46<00:00,  1.07s/it]


Epoch 4 finished, train loss: 0.024610680376374445


100%|██████████| 43/43 [00:20<00:00,  2.07it/s]


Epoch 5 finished, train loss: 0.02359141015209431


100%|██████████| 43/43 [00:15<00:00,  2.69it/s]


Epoch 6 finished, train loss: 0.023209836333990097


100%|██████████| 43/43 [00:13<00:00,  3.13it/s]


Epoch 7 finished, train loss: 0.023027042360153308


100%|██████████| 43/43 [00:13<00:00,  3.19it/s]


Epoch 8 finished, train loss: 0.022926568941668023


100%|██████████| 43/43 [00:16<00:00,  2.57it/s]


Epoch 9 finished, train loss: 0.022867907011924787


100%|██████████| 43/43 [00:13<00:00,  3.12it/s]


Epoch 10 finished, train loss: 0.02282653081902238


100%|██████████| 43/43 [00:10<00:00,  3.97it/s]


Epoch 11 finished, train loss: 0.022778374900998072


100%|██████████| 43/43 [00:10<00:00,  3.99it/s]


Epoch 12 finished, train loss: 0.022654397802990535


100%|██████████| 43/43 [00:10<00:00,  4.14it/s]


Epoch 13 finished, train loss: 0.022505732643049815


100%|██████████| 43/43 [00:18<00:00,  2.34it/s]


Epoch 14 finished, train loss: 0.022382597893823023


100%|██████████| 43/43 [00:14<00:00,  2.98it/s]


Epoch 15 finished, train loss: 0.022209780810530797


100%|██████████| 43/43 [00:11<00:00,  3.73it/s]


Epoch 16 finished, train loss: 0.022078998213590578


100%|██████████| 43/43 [00:10<00:00,  4.02it/s]


Epoch 17 finished, train loss: 0.021944463079751925


100%|██████████| 43/43 [00:10<00:00,  4.27it/s]


Epoch 18 finished, train loss: 0.021717366567531298


100%|██████████| 43/43 [00:09<00:00,  4.36it/s]


Epoch 19 finished, train loss: 0.021541015136727067


100%|██████████| 43/43 [00:09<00:00,  4.72it/s]


Epoch 20 finished, train loss: 0.02128559272996215


100%|██████████| 43/43 [00:09<00:00,  4.55it/s]


Epoch 21 finished, train loss: 0.020864117977231048


100%|██████████| 43/43 [00:09<00:00,  4.46it/s]


Epoch 22 finished, train loss: 0.020456047035580457


100%|██████████| 43/43 [00:09<00:00,  4.44it/s]


Epoch 23 finished, train loss: 0.020167776924926183


100%|██████████| 43/43 [00:15<00:00,  2.76it/s]


Epoch 24 finished, train loss: 0.019891838867997013


100%|██████████| 43/43 [00:13<00:00,  3.17it/s]


Epoch 25 finished, train loss: 0.019595871051383572


100%|██████████| 43/43 [00:11<00:00,  3.68it/s]


Epoch 26 finished, train loss: 0.019279249532278196


100%|██████████| 43/43 [00:09<00:00,  4.64it/s]


Epoch 27 finished, train loss: 0.018876210504839588


100%|██████████| 43/43 [00:09<00:00,  4.42it/s]


Epoch 28 finished, train loss: 0.018407802595648656


100%|██████████| 43/43 [00:17<00:00,  2.50it/s]


Epoch 29 finished, train loss: 0.017927492318978142


100%|██████████| 43/43 [00:12<00:00,  3.53it/s]


Epoch 30 finished, train loss: 0.017488854179202123


100%|██████████| 43/43 [00:10<00:00,  4.10it/s]


Epoch 31 finished, train loss: 0.01711505398067624


100%|██████████| 43/43 [00:09<00:00,  4.61it/s]


Epoch 32 finished, train loss: 0.01671597370228102


100%|██████████| 43/43 [00:09<00:00,  4.68it/s]


Epoch 33 finished, train loss: 0.01628219953543225


100%|██████████| 43/43 [00:09<00:00,  4.64it/s]


Epoch 34 finished, train loss: 0.015770976041811845


100%|██████████| 43/43 [00:09<00:00,  4.58it/s]


Epoch 35 finished, train loss: 0.015316356753194055


100%|██████████| 43/43 [00:09<00:00,  4.74it/s]


Epoch 36 finished, train loss: 0.014947118497518607


100%|██████████| 43/43 [00:09<00:00,  4.53it/s]


Epoch 37 finished, train loss: 0.01451470442982607


100%|██████████| 43/43 [00:09<00:00,  4.68it/s]


Epoch 38 finished, train loss: 0.014268360165662543


100%|██████████| 43/43 [00:13<00:00,  3.31it/s]


Epoch 39 finished, train loss: 0.013968021024105161


100%|██████████| 43/43 [00:15<00:00,  2.86it/s]


Epoch 40 finished, train loss: 0.013679721760888433


100%|██████████| 43/43 [00:11<00:00,  3.59it/s]


Epoch 41 finished, train loss: 0.013388315410634805


100%|██████████| 43/43 [00:09<00:00,  4.64it/s]


Epoch 42 finished, train loss: 0.01318523961357599


100%|██████████| 43/43 [00:09<00:00,  4.65it/s]


Epoch 43 finished, train loss: 0.012906751595437527


100%|██████████| 43/43 [00:11<00:00,  3.64it/s]


Epoch 44 finished, train loss: 0.012646295909964762


100%|██████████| 43/43 [00:15<00:00,  2.72it/s]


Epoch 45 finished, train loss: 0.012382464916553608


100%|██████████| 43/43 [00:12<00:00,  3.44it/s]


Epoch 46 finished, train loss: 0.012137643754655538


100%|██████████| 43/43 [00:11<00:00,  3.86it/s]


Epoch 47 finished, train loss: 0.011935095042856626


100%|██████████| 43/43 [00:48<00:00,  1.13s/it]


Epoch 48 finished, train loss: 0.011744829154638358


100%|██████████| 43/43 [00:42<00:00,  1.01it/s]


Epoch 49 finished, train loss: 0.011571621695576711


100%|██████████| 43/43 [00:25<00:00,  1.68it/s]


Epoch 50 finished, train loss: 0.011342819065375383


100%|██████████| 43/43 [00:09<00:00,  4.64it/s]


Epoch 51 finished, train loss: 0.011132192191516245


100%|██████████| 43/43 [00:09<00:00,  4.74it/s]


Epoch 52 finished, train loss: 0.010948054622425589


100%|██████████| 43/43 [00:09<00:00,  4.50it/s]


Epoch 53 finished, train loss: 0.01076779291466918


100%|██████████| 43/43 [00:08<00:00,  4.82it/s]


Epoch 54 finished, train loss: 0.010565367917162042


100%|██████████| 43/43 [00:09<00:00,  4.68it/s]


Epoch 55 finished, train loss: 0.010418543743706026


100%|██████████| 43/43 [00:09<00:00,  4.54it/s]


Epoch 56 finished, train loss: 0.01023105941279683


100%|██████████| 43/43 [00:09<00:00,  4.66it/s]


Epoch 57 finished, train loss: 0.010012224955527588


100%|██████████| 43/43 [00:09<00:00,  4.77it/s]


Epoch 58 finished, train loss: 0.009877057951810054


100%|██████████| 43/43 [00:09<00:00,  4.68it/s]


Epoch 59 finished, train loss: 0.009714350569993258


100%|██████████| 43/43 [00:09<00:00,  4.70it/s]


Epoch 60 finished, train loss: 0.009557869512761054


100%|██████████| 43/43 [00:09<00:00,  4.58it/s]


Epoch 61 finished, train loss: 0.009465699181567097


100%|██████████| 43/43 [00:09<00:00,  4.59it/s]


Epoch 62 finished, train loss: 0.009252642530428117


100%|██████████| 43/43 [00:09<00:00,  4.74it/s]


Epoch 63 finished, train loss: 0.00909142535136536


100%|██████████| 43/43 [00:09<00:00,  4.77it/s]


Epoch 64 finished, train loss: 0.008972412224339192


100%|██████████| 43/43 [00:14<00:00,  3.02it/s]


Epoch 65 finished, train loss: 0.008824192593957102


100%|██████████| 43/43 [00:14<00:00,  3.06it/s]


Epoch 66 finished, train loss: 0.008727577222554489


100%|██████████| 43/43 [00:11<00:00,  3.81it/s]


Epoch 67 finished, train loss: 0.008562948738852905


100%|██████████| 43/43 [00:10<00:00,  4.28it/s]


Epoch 68 finished, train loss: 0.008490206738717334


100%|██████████| 43/43 [00:16<00:00,  2.66it/s]


Epoch 69 finished, train loss: 0.008354457257705372


100%|██████████| 43/43 [00:10<00:00,  4.20it/s]


Epoch 70 finished, train loss: 0.008274479725852955


100%|██████████| 43/43 [00:09<00:00,  4.76it/s]


Epoch 71 finished, train loss: 0.008179516509868378


100%|██████████| 43/43 [00:08<00:00,  4.80it/s]


Epoch 72 finished, train loss: 0.008046386958381464


100%|██████████| 43/43 [00:08<00:00,  4.78it/s]


Epoch 73 finished, train loss: 0.007952522410642962


100%|██████████| 43/43 [00:08<00:00,  4.97it/s]


Epoch 74 finished, train loss: 0.007864936883019846


100%|██████████| 43/43 [00:08<00:00,  4.95it/s]


Epoch 75 finished, train loss: 0.00777162738187715


100%|██████████| 43/43 [00:17<00:00,  2.49it/s]


Epoch 76 finished, train loss: 0.007712434316703746


100%|██████████| 43/43 [00:11<00:00,  3.64it/s]


Epoch 77 finished, train loss: 0.007615147884068794


100%|██████████| 43/43 [00:09<00:00,  4.39it/s]


Epoch 78 finished, train loss: 0.00750651829984299


100%|██████████| 43/43 [00:08<00:00,  4.86it/s]


Epoch 79 finished, train loss: 0.007434897817844568


100%|██████████| 43/43 [00:09<00:00,  4.75it/s]


Epoch 80 finished, train loss: 0.007327829736696426


100%|██████████| 43/43 [00:08<00:00,  4.83it/s]


Epoch 81 finished, train loss: 0.007262628070663574


100%|██████████| 43/43 [00:09<00:00,  4.77it/s]


Epoch 82 finished, train loss: 0.007173952256697555


100%|██████████| 43/43 [00:09<00:00,  4.71it/s]


Epoch 83 finished, train loss: 0.007139477954614301


100%|██████████| 43/43 [00:09<00:00,  4.64it/s]


Epoch 84 finished, train loss: 0.007049035551676223


100%|██████████| 43/43 [00:16<00:00,  2.66it/s]


Epoch 85 finished, train loss: 0.0069654259267588


100%|██████████| 43/43 [00:11<00:00,  3.67it/s]


Epoch 86 finished, train loss: 0.006874017926409494


100%|██████████| 43/43 [00:11<00:00,  3.90it/s]


Epoch 87 finished, train loss: 0.006852972093796314


100%|██████████| 43/43 [00:09<00:00,  4.64it/s]


Epoch 88 finished, train loss: 0.006737322319125713


100%|██████████| 43/43 [00:17<00:00,  2.42it/s]


Epoch 89 finished, train loss: 0.0066423020542187745


100%|██████████| 43/43 [00:11<00:00,  3.64it/s]


Epoch 90 finished, train loss: 0.006595588787350544


100%|██████████| 43/43 [00:08<00:00,  4.81it/s]


Epoch 91 finished, train loss: 0.00653621971303987


100%|██████████| 43/43 [00:09<00:00,  4.78it/s]


Epoch 92 finished, train loss: 0.006497042134490817


100%|██████████| 43/43 [00:08<00:00,  4.81it/s]


Epoch 93 finished, train loss: 0.006451999411246804


100%|██████████| 43/43 [00:15<00:00,  2.69it/s]


Epoch 94 finished, train loss: 0.006359523920299009


100%|██████████| 43/43 [00:11<00:00,  3.64it/s]


Epoch 95 finished, train loss: 0.006310492407443912


100%|██████████| 43/43 [00:10<00:00,  4.04it/s]


Epoch 96 finished, train loss: 0.006235544526473034


100%|██████████| 43/43 [00:18<00:00,  2.37it/s]


Epoch 97 finished, train loss: 0.00615760374294464


100%|██████████| 43/43 [00:10<00:00,  4.00it/s]


Epoch 98 finished, train loss: 0.006103636258346743


100%|██████████| 43/43 [00:09<00:00,  4.48it/s]


Epoch 99 finished, train loss: 0.006045668840754864


100%|██████████| 43/43 [00:09<00:00,  4.73it/s]

Epoch 100 finished, train loss: 0.005998720902256494





In [6]:
model = model.cpu()

# Compare predicted and true flag indices for the first sample, then show
# the raw logits and the encoded input they were computed from.
sample_inputs, sample_labels = dataset[0]
print((model(sample_inputs).sigmoid() > 0.5).nonzero())
print(sample_labels.nonzero())
print(model(sample_inputs))
print(sample_inputs)




tensor([[ 7],
        [24]])
tensor([[ 7],
        [24]])
tensor([ -6.0549,  -1.7437,  -8.8945,  -6.5495, -12.1324,  -9.5775,  -9.8230,
          1.6025,  -7.1738, -11.0155,  -9.0088,  -1.8157,  -8.9323,  -5.6583,
         -4.3313,  -6.2582,  -4.8948,  -9.1755,  -4.7792,  -6.1798,  -4.2904,
         -7.3718,  -5.4732,  -6.3603,   1.0459,  -3.9689,  -7.9345,  -3.2641,
         -8.0317,  -4.6927, -12.7539,  -6.9042, -11.8692,  -6.0222,  -8.3910,
         -8.0081,  -7.9376,  -6.0219,  -5.7270,  -9.2602,  -4.6173,  -6.9237,
         -8.2855,  -5.9983,  -9.6159,  -8.6556,  -9.7714,  -9.9828,  -8.5212,
         -4.1035,  -4.6489,  -3.8393,  -8.9859,  -7.0772, -10.0190,  -8.1355,
         -8.1221,  -6.0019,  -5.7970,  -8.8273,  -9.7669,  -7.9812,  -7.7109,
        -10.7248, -10.5166, -10.0818,  -4.0407,  -1.6661,  -8.6470,  -7.8943,
         -9.7733,  -4.7770,  -7.8542,  -5.6136,  -6.8574,  -7.3624,  -9.2182,
         -8.2442,  -9.1187,  -7.1989,  -7.4427,  -5.8105,  -9.0060,  -4.0272,
      

In [13]:
# Invert the flag -> index mapping so predicted indices can be shown by name.
idx2flag = {index: name for name, index in dataset._flag2idx.items()}
predicted_rows = (model(dataset[0][0]).sigmoid() > 0.5).nonzero().detach().cpu().numpy().tolist()
for t in predicted_rows:
    # Each row holds a single class index.
    print(idx2flag[t[0]])

clang
Interp


In [15]:
# Serialises the whole module object (not just a state_dict) to disk.
# NOTE(review): loading it back needs the FlagClassifier class to be available
# to the unpickler - confirm that is acceptable for how this file is shared.
torch.save(model, "./flags_classification.pt")

In [64]:
# NOTE(review): torch.load unpickles a full module here, which can execute
# arbitrary code from the file - only load checkpoints produced by this
# notebook. Newer PyTorch releases may reject full-module pickles unless
# weights_only=False is passed; confirm against the installed version.
model2 = torch.load("flags_classification.pt")
# Switch to inference mode before predicting.
model2.eval()
# Re-parse the training file to obtain the vocabulary and flag mappings.
dataset2 = FlagsDataset("dataset.txt")

def find_name(line):
    """Return the slice of ``line`` from its first ``/`` up to its last space.

    For a header such as ``diff --git a/dir/f.c b/dir/f.c`` this yields
    ``/dir/f.c``.

    Raises:
        ValueError: If ``line`` contains no ``/`` or no space.
    """
    start = line.index('/')
    end = line.rindex(' ')
    return line[start:end]

def parser(name, result):
    """Collect the file path from every ``diff --git`` header in a diff file.

    Args:
        name: Path of the diff file to read.
        result: List that receives one extracted path per header line, in
            file order. It is modified in place.

    Returns:
        None.

    Every header contributes its path. Previously a header that directly
    followed another file's section only reset the reader state, so every
    second file in a multi-file diff was silently skipped.
    """
    # The context manager closes the file even if a header is malformed.
    with open(name, 'r') as file:
        for line in file:
            if "diff --git" in line:
                result.append(find_name(line))

# Gather the paths touched by the diff.
lst = []
parser("super_risc.test", lst)

idx2flag = {idx: flag for flag, idx in dataset2._flag2idx.items()}

# Encode all known path components as one token sequence. The dataset's
# word -> index map gives the same ids as ``_vocab.index(elem) + 1`` without
# a linear scan per token; components missing from the vocabulary are skipped.
tensor = []
for path in lst:
    for elem in path.split("/"):
        word_idx = dataset2._word2idx.get(elem)
        if word_idx is not None:
            tensor.append(word_idx)

if tensor:
    # int64 matches the dtype produced by FlagsDataset.__getitem__, and
    # no_grad avoids building an autograd graph for a pure prediction.
    with torch.no_grad():
        logits = model2(torch.tensor(tensor, dtype=torch.int64))
    for t in (logits.sigmoid() > 0.5).nonzero().flatten().tolist():
        print(idx2flag[t])
else:
    # An empty sequence cannot be fed through the LSTM.
    print("No known path components found; nothing to classify.")