In [1]:
import os
import re
import threading
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, fields, asdict

import pandas as pd
import seaborn as sns
import torch
from tqdm import tqdm

from model.model import PageAccModel
from logfile_reader import read_pages, save_pages_accs, read_optimal_results, Page

# Load data

In [5]:
# TRAIN_PART  - fraction of the page-access trace kept for training (the tail is dropped).
# BATCH_SIZE  - consecutive accesses fed to the model per optimizer step.
# BUFFER_SIZE - number of simulated buffer slots; also used as the "page not
#               resident" marker in the access features and passed to
#               PageAccModel (presumably its output width - TODO confirm).
TRAIN_PART, BATCH_SIZE, BUFFER_SIZE = 0.7, 2, 128

In [6]:
# Read the full access trace, then keep only its leading TRAIN_PART fraction.
pages = read_pages("train_data/tpcc_logfile")
train_size = int(TRAIN_PART * len(pages))
train_pages = pages[0:train_size]
# Drop the name bound to the full trace; only the training prefix is used below.
del pages

In [9]:
# Precomputed optimal eviction decisions for the training accesses (the next
# cell checks there is one entry per element of train_pages). As consumed by
# get_model_optimal_res: an empty entry means "no eviction needed", otherwise
# entry[0][0] is the buffer slot to evict.
optimal_results = read_optimal_results("train_data/tpcc_logfile_train_victims")

In [10]:
# Every training access must have a matching optimal-eviction record; the
# training targets are looked up by the same index in both sequences.
assert len(optimal_results) == len(train_pages), (
    f"optimal_results has {len(optimal_results)} entries "
    f"but train_pages has {len(train_pages)}"
)
print(len(train_pages))

1547962


# Train

In [18]:
def get_model_optimal_res(pages, buffer, current_index, results=None):
    """Build the supervision target for a single page access.

    Looks up the precomputed optimal decision for access ``current_index`` and
    turns it into a one-hot target over buffer slots.

    Args:
        pages: Sequence of accessed pages; ``pages[current_index]`` is the access
            being labelled. Elements must provide ``get_page_id()``.
        buffer: Current buffer contents, one entry per slot; entries must
            provide ``get_page_id()``.
        current_index: Index of the access in ``pages`` (and in ``results``).
        results: Per-access victim lists. Defaults to the notebook-level
            ``optimal_results`` (looked up at call time). An empty entry means
            no eviction is needed; otherwise ``entry[0][0]`` is the victim slot.

    Returns:
        ``(res, slot)`` where ``res`` is a list of ``len(buffer)`` ints.
        If an eviction is needed, ``res`` is one-hot at the victim slot and
        ``slot`` is that victim. Otherwise ``res`` is all zeros and ``slot`` is
        the buffer slot that already holds the accessed page.

    Raises:
        IndexError: ``current_index`` is outside ``pages``.
        ValueError: no eviction is recorded, but the accessed page is not in
            ``buffer`` (results and simulated buffer state disagree).
    """
    if results is None:
        results = optimal_results  # notebook-level global

    # Fail immediately with a clear error instead of printing and then crashing
    # on an unrelated lookup further down.
    if not 0 <= current_index < len(pages):
        raise IndexError(
            f"current_index={current_index} is out of range for {len(pages)} pages"
        )

    res = [0] * len(buffer)
    victims_rates = results[current_index]

    if len(victims_rates) == 0:
        # No eviction recorded: the page must already be resident; report its slot.
        page_id = pages[current_index].get_page_id()
        for slot, resident in enumerate(buffer):
            if resident.get_page_id() == page_id:
                return res, slot
        raise ValueError(
            f"access {current_index}: no eviction recorded but page {page_id} "
            f"is not in the buffer"
        )

    victim = victims_rates[0][0]
    res[victim] = 1
    return res, victim

In [13]:
def get_train_data(pages, buffer, batch_start, batch_end):
    """Simulate a batch of page accesses against ``buffer`` and build model inputs.

    For each access ``i`` in ``[batch_start, batch_end)`` this:
      1. records a flattened snapshot of the buffer *before* the access,
      2. gets the target from ``get_model_optimal_res``,
      3. applies that decision to ``buffer``, so the next access sees the
         updated state.

    Args:
        pages: Sequence of ``Page`` dataclass instances (the access trace).
        buffer: List of ``Page`` objects, one per slot. Mutated in place.
        batch_start: First access index (inclusive).
        batch_end: Last access index (exclusive); clamped to ``len(pages)``.

    Returns:
        Tuple ``(pages_acc, buffers, optimal_predictions, buffer, hit_fail_mask)``:
          - ``pages_acc``: float tensor with one row of ``Page`` field values
            per access. The last column is overwritten with ``BUFFER_SIZE`` when
            an eviction is needed, and with the page's slot otherwise.
          - ``buffers``: float tensor with one row per access. Each row is the
            field values of every buffer page, concatenated (pre-access state).
          - ``optimal_predictions``: float tensor of per-access targets (one-hot
            for evictions, all zeros for hits).
          - ``buffer``: the same list object that was passed in, after updates.
          - ``hit_fail_mask``: bool tensor, True where an eviction was needed.
    """
    # Clamp so a batch running past the trace is shortened instead of failing
    # deep inside get_model_optimal_res with an out-of-range index.
    batch_end = min(batch_end, len(pages))
    pages_acc = torch.Tensor([list(asdict(page).values()) for page in pages[batch_start:batch_end]])

    buffers = []
    optimal_predictions = []
    hit_fail_mask = []

    for offset, i in enumerate(range(batch_start, batch_end)):
        # Snapshot of every slot's field values as seen before access i.
        buffers.append([value for obj in buffer for value in asdict(obj).values()])

        res, slot = get_model_optimal_res(pages, buffer, i)
        optimal_predictions.append(res)

        if any(res):
            # Miss: evict the optimal victim and load the accessed page into its slot.
            buffer[slot] = deepcopy(pages[i])
            buffer[slot].hit = slot
            # NOTE(review): the last Page field is presumably `hit`; BUFFER_SIZE
            # serves as the "not resident" marker - confirm against Page.
            pages_acc[offset][-1] = BUFFER_SIZE
            hit_fail_mask.append(True)
        else:
            # Hit: record the slot that already holds the page.
            pages_acc[offset][-1] = slot
            hit_fail_mask.append(False)

    return (
        pages_acc,
        torch.Tensor(buffers),
        torch.Tensor(optimal_predictions),
        buffer,
        torch.tensor(hit_fail_mask, dtype=torch.bool),
    )

In [14]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [15]:
# The first argument is the number of Page fields (presumably the per-access
# feature width). 256 and 512 are presumably hidden sizes, and BUFFER_SIZE is
# presumably the number of output slots - TODO confirm against PageAccModel.
model = PageAccModel(len(fields(Page)), 256, 512, BUFFER_SIZE)
model = model.to(device)  # nn.Module.to moves parameters and returns the module itself

In [None]:
# Resume from the last saved model weights when they exist, so a fresh
# "Restart & Run All" does not fail before the first checkpoint is written.
if os.path.exists("drive/MyDrive/model1.pth"):
    model.load_state_dict(torch.load("drive/MyDrive/model1.pth", map_location=device, weights_only=True))
else:
    print("drive/MyDrive/model1.pth not found; training from freshly initialised weights")
model.to(device)

In [16]:
# Keep per-sample losses (reduction="none") so the training loop can mask out
# accesses that did not require an eviction before averaging.
loss = torch.nn.CrossEntropyLoss(reduction="none")
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
# Resume the optimizer state (e.g. Adam moments) when a checkpoint exists;
# otherwise keep the freshly constructed optimizer instead of crashing.
if os.path.exists("drive/MyDrive/opt1.pth"):
    optimizer.load_state_dict(torch.load("drive/MyDrive/opt1.pth", map_location=device, weights_only=True))
else:
    print("drive/MyDrive/opt1.pth not found; starting with a fresh optimizer state")

In [20]:
# Train over the access trace in consecutive full batches. The model's returned
# state (h, c) is passed into the next batch and detached after each step, so
# gradients never flow across batch boundaries.
model.train()
for epoch in range(20):
    # Each epoch replays the trace from the start with an empty buffer, so the
    # carried model state restarts too instead of leaking from the previous epoch.
    h, c = None, None
    # One distinct placeholder per slot (`[obj] * n` would alias a single object).
    buffer = [Page(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0) for _ in range(BUFFER_SIZE)]

    loss_sum = 0.0
    loss_batches = 0  # batches that actually contributed a loss (had >= 1 miss)
    pbar = tqdm(range(0, len(train_pages), BATCH_SIZE), desc=f"epoch {epoch}")
    for batch_start in pbar:
        batch_end = batch_start + BATCH_SIZE
        # Only full batches are used; `>` (not `>=`) keeps a batch ending exactly
        # at the end of the trace.
        if batch_end > len(train_pages):
            continue
        pages_acc, buffers, optimal_predictions, buffer, hit_fail_mask = get_train_data(
            train_pages, buffer, batch_start, batch_end
        )
        # Targets arrive one-hot; hits are all-zero rows and get masked out below.
        targets = torch.argmax(optimal_predictions, dim=1).to(device)
        hit_fail_mask = hit_fail_mask.to(device)

        optimizer.zero_grad()
        out, h, c = model(pages_acc.to(device), buffers.to(device), False, h, c)

        if hit_fail_mask.any():
            # Only accesses that required an eviction carry a training signal.
            loss_value = loss(out, targets)[hit_fail_mask].mean()
            loss_value.backward()
            optimizer.step()

            loss_sum += loss_value.item()
            loss_batches += 1
            pbar.set_postfix_str(f"loss={loss_sum / loss_batches}")

        # Cut the graph so the next backward pass does not reach into this batch.
        h = h.detach()
        c = c.detach()

    torch.save(model.state_dict(), "drive/MyDrive/model1.pth")
    torch.save(optimizer.state_dict(), "drive/MyDrive/opt1.pth")

  0%|          | 377/773981 [00:20<11:26:40, 18.78it/s, loss=0.7860872256977173]


KeyboardInterrupt: 