In [2]:
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm, trange
import pickle
import numpy as np
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from rich.progress import track
from tqdm import tqdm, trange
# from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
# from torch.utils import tensorboard
import tensorboardX
import random
from matplotlib import pyplot as plt

In [3]:
class ReplayBuffer(Dataset):
    def __init__(self, data):
        self.data = data
        self.length = len(data)
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        return self.data[idx]

In [4]:
from torch_geometric.nn.models import GraphSAGE

class GraphSAGE_(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout, num_layers):
        super(GraphSAGE_, self).__init__()
        self.sage = GraphSAGE(in_channels=in_channels, hidden_channels=hidden_channels, out_channels=hidden_channels, dropout=dropout, num_layers=num_layers)
        self.fc = torch.nn.Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        x = self.sage(x, edge_index)
        x = global_mean_pool(x, batch)
        # x = self.fc(x)
        return x

In [5]:
state_encoder = torch.load('sage128.pth')
state_encoder.eval()

GraphSAGE_(
  (sage): GraphSAGE(-1, 128, num_layers=8)
  (fc): Linear(in_features=128, out_features=1, bias=True)
)

In [6]:
pickle.load(open('project\\task2\\project_data2\\adder_0.pkl', 'rb'))

{'input': ['adder_',
  'adder_4',
  'adder_42',
  'adder_423',
  'adder_4234',
  'adder_42345',
  'adder_423455',
  'adder_4234552',
  'adder_42345526',
  'adder_423455260',
  'adder_4234552604'],
 'target': [0.02755838064327021,
  0.00795755516894936,
  0.00795755516894936,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0]}

In [7]:
ratio = 4
class PolicyNet(nn.Module):
    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        
        self.fc1 = nn.Linear(in_dim, ratio * in_dim)
        self.fc2 = nn.Linear(ratio * in_dim, ratio * in_dim)
        self.fc3 = nn.Linear(ratio * in_dim, in_dim)
        self.fc4 = nn.Linear(in_dim, out_dim)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.softmax(x, dim=1)

class ValueNet(nn.Module):
    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        
        self.fc1 = nn.Linear(in_dim, ratio * in_dim)
        self.fc2 = nn.Linear(ratio * in_dim, ratio * in_dim)
        self.fc3 = nn.Linear(ratio * in_dim, in_dim)
        self.fc4 = nn.Linear(in_dim, out_dim)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x)) 
        x = self.fc4(x)
        return x

In [8]:
policy = PolicyNet(in_dim=128, out_dim=7).cuda()
value = ValueNet(in_dim=128, out_dim=1).cuda()

In [9]:
replay_buffer = []
_, _, file_list = next(os.walk('project\\task2\\project_data2'))
print(type(file_list))
block_size = 4000

state_encoder = state_encoder.cpu()

def generate_replay_buffer(idx):
    replay_buffer = []
    data = pickle.load(open('train_data_tensor_' + str(idx) + '.pkl', 'rb'))
    files = file_list[idx * block_size: (idx + 1) * block_size]
    j = -1
    for f in tqdm(files):
        f_data = pickle.load(open('project\\task2\\project_data2\\' + f, 'rb'))
        sz = len(f_data['input'])
        ops = f_data['input'][-1].split('_')[-1]
        # print(ops, sz)
        
        assert len(ops) == sz - 1
        
        for i in range(1, sz):
            # d = Data(x=data[j + i]['x'], edge_index=data[j + i]['edge_index'])
            curState = state_encoder(x=data[j + i]['x'].float(), edge_index=data[j + i]['edge_index'], batch=None).cpu()
            # curState = state_encoder(d.x.float().cuda(), d.edge_index.cuda(), d.batch).cpu()
            # d = Data(x=data[j + i + 1]['x'], edge_index=data[j + i + 1]['edge_index'])
            # nextState = state_encoder(d.x.float().cuda(), d.edge_index.cuda(), d.batch).cpu()
            nextState = state_encoder(x=data[j + i + 1]['x'].float(), edge_index=data[j + i + 1]['edge_index'], batch=None).cpu()
            op = ops[i - 1]
            reward = data[j + i]['y']
            assert curState.device == torch.device('cpu')
            assert f_data['target'][i - 1] == reward
            replay_buffer.append((curState, op, nextState, reward))
            
        j += sz
        
    assert len(replay_buffer) == 40000
    return DataLoader(Dataset(replay_buffer), batch_size=16, shuffle=True)

<class 'list'>


In [10]:
# graph, target
dataset_list = np.array([0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17])

In [11]:
optimizer1 = torch.optim.AdamW(policy.parameters(), lr=0.001, weight_decay=1e-5)
optimizer2 = torch.optim.AdamW(value.parameters(), lr=0.001, weight_decay=1e-5)
loss1 = []
loss2 = []
for idx in dataset_list:
    dataloader = generate_replay_buffer(idx)
    for i, (curState, op, nextState, reward) in enumerate(tqdm(dataloader)):
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        curValue = value(curState.cuda())
        nextValue = value(nextState.cuda())
        op_ = policy(curState.cuda())
        y = nextValue + reward
        td_error = y - curValue
        policy_loss = F.cross_entropy(op_, op.cuda()) * td_error
        value_loss = td_error * curValue
        policy_loss.backward()
        optimizer1.step()
        value_loss.backward()
        optimizer2.step()
        if i % 100 == 99:
            loss1.append(policy_loss.item())
            loss2.append(value_loss.item())
plt.plot(loss1)
plt.plot(loss2)
plt.show()

 11%|█▏        | 450/4000 [03:23<26:43,  2.21it/s]  
