In [1]:
import torch

In [185]:
@torch.no_grad()
def parallel_inference(
    model,
    idx,
    parallel,
    memory_size,
    max_steps,
    stop_idx=-1,
    tolerance=1.0,
    verbose=True,
):
    """Evaluate the first ``parallel`` loop iterations of a looped model in parallel.

    The loop ``x[t+1] = model.f(x[t], t)`` is treated as a fixed-point problem
    over the whole trajectory and solved with Picard iteration,

        x^{k+1}[t+1] = x[0] + sum_{i<=t} (model.f(x^k[i], i) - x^k[i]),

    optionally corrected by a triangular Anderson-acceleration step built from
    the differences of the last ``memory_size`` iterates and residuals.

    Args:
        model: object exposing ``transformer.embedding``, ``transformer.drop``,
            ``transformer.ln_f``, ``lm_head``, ``f(x, t)`` and ``num_loop``.
        idx: token ids fed to ``model.transformer.embedding``. The latents are
            reshaped to ``(parallel_len, 1, seq_len, hidden_dim)``, so the
            batch size must be 1.
        parallel: number of loop iterations solved jointly. Values larger than
            ``model.num_loop`` are clamped to ``model.num_loop``.
        memory_size: number of past iterates kept for Anderson acceleration.
            With fewer than two stored iterates no correction is applied and
            the step is a plain Picard update.
        max_steps: maximum number of solver iterations (must be >= 1, because
            the collected logits are stacked at the end).
        stop_idx: sequence position whose residual norm is monitored.
        tolerance: the solver stops once the residual norm of the last latent
            in the window, at ``stop_idx``, is <= ``tolerance``.
        verbose: when True, print the per-step debug information.

    Returns:
        Tensor of shape ``(num_steps, 1, seq_len, vocab)``: the logits computed
        from the last latent of the window after every solver iteration.

    NOTE(review): ``begin_idx`` / ``end_idx`` are never advanced, so only the
    first ``parallel`` loop iterations are evaluated. If ``parallel`` is
    smaller than ``model.num_loop`` the returned logits correspond to the
    latent after ``parallel`` iterations, not to the full-depth model.
    """
    emb = model.transformer.embedding(idx)
    x = model.transformer.drop(emb)
    _, seq_len, hidden_dim = x.size()

    num_loops = model.num_loop
    timesteps = torch.arange(0, num_loops, device=x.device)

    # A window longer than the loop depth makes the latent / timestep slices
    # disagree in length, so clamp it.
    parallel = min(parallel, num_loops)
    begin_idx = 0
    end_idx = begin_idx + parallel

    # latents_time_evolution_buffer[t] is the current estimate of the latent
    # after t loop iterations; every entry starts at the embedding.
    latents_time_evolution_buffer = torch.stack([x] * (num_loops + 1))

    # History of residuals / iterates, one row per stored solver iteration.
    memory_shape = (1, num_loops + 1, 1, seq_len, hidden_dim)
    residual_memory = None
    samples_memory = None
    memory_indexes = torch.zeros(num_loops + 1, device=x.device, dtype=torch.long)

    logits_list = []
    stop_criteria = tolerance + 1.0
    n = 0
    while stop_criteria > tolerance and len(logits_list) < max_steps:
        if verbose:
            print("step", n)
        n += 1

        parallel_len = end_idx - begin_idx
        block_latents = latents_time_evolution_buffer[begin_idx:end_idx]  # x^k
        t_vec = timesteps[begin_idx:end_idx]

        # Per-timestep increment f(x^k[i], i) - x^k[i].
        model_output = torch.zeros_like(block_latents)
        for _i, _t in enumerate(t_vec):
            if verbose:
                print(_i)
                print(block_latents[_i][0, stop_idx, :50])
            model_output[_i] = model.f(block_latents[_i], _t.item()) - block_latents[_i]
            if verbose:
                print(model_output[_i][0, stop_idx, :50])
        delta = model_output.reshape(parallel_len, 1, seq_len, hidden_dim)
        cumulative_delta = torch.cumsum(delta, dim=0)
        block_latents_new = latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta  # f(x^k)

        # Logits of the last latent in the window; computed once and reused
        # for both the debug print and the returned list.
        logits = model.lm_head(model.transformer.ln_f(block_latents_new[-1]))
        if verbose:
            print(block_latents_new[-1][0, stop_idx, :50])
            print("pred", torch.argmax(logits[0, stop_idx]))

        # Residual f(x^k) - x^k, shape [parallel_len, 1, seq_len, hidden_dim].
        previous_latents = latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]
        cur_error_vec = block_latents_new - previous_latents
        ho_residual = cur_error_vec.to(torch.float64)
        cur_error = torch.linalg.vector_norm(cur_error_vec[:, 0, stop_idx, :], dim=-1)  # [parallel_len]
        if verbose:
            print(cur_error)

        # Store this iteration's residual and iterate (float64) in the history.
        padded_residual = torch.zeros(memory_shape, device=x.device, dtype=torch.float64)
        padded_samples = torch.zeros(memory_shape, device=x.device, dtype=torch.float64)
        padded_residual[0, t_vec] = ho_residual
        padded_samples[0, t_vec] = previous_latents.to(torch.float64)
        memory_indexes[t_vec] = torch.clamp(memory_indexes[t_vec] + 1, max=memory_size)

        # Anderson correction; stays zero (plain Picard) without usable history.
        Gf = torch.zeros_like(ho_residual)
        if residual_memory is None:
            residual_memory = padded_residual
            samples_memory = padded_samples
        else:
            residual_memory = torch.cat([residual_memory, padded_residual], dim=0)[-memory_size:]
            samples_memory = torch.cat([samples_memory, padded_samples], dim=0)[-memory_size:]

            history_len = residual_memory.shape[0]
            use_memory_max = int(memory_indexes[t_vec].max())
            m_k = min(use_memory_max, history_len - 1)  # number of difference columns

            if m_k > 0:
                residual_diff_t = (residual_memory[1:] - residual_memory[:-1])[:, t_vec]
                sample_diff_t = (samples_memory[1:] - samples_memory[:-1])[:, t_vec]
                res_diff_mat = residual_diff_t[:m_k]  # [m_k, parallel_len, 1, seq_len, hidden_dim]
                sample_diff_mat = sample_diff_t[:m_k]  # [m_k, parallel_len, 1, seq_len, hidden_dim]

                # Flip -> cumsum -> flip yields, for every timestep j, the sum
                # over timesteps j' >= j (the "triangular" structure).
                flip_res_diff_mat = torch.flip(res_diff_mat, dims=[1])
                B = torch.einsum("ijklm,pjklm->ipj", flip_res_diff_mat, flip_res_diff_mat)  # [m_k, m_k, parallel_len]
                B = torch.flip(torch.cumsum(B, dim=2), dims=[2]).permute(2, 0, 1)  # [parallel_len, m_k, m_k]

                flip_ho_residual = torch.flip(ho_residual, dims=[0])
                d = torch.einsum("ijklm,jklm->ji", flip_res_diff_mat, flip_ho_residual)  # [parallel_len, m_k]
                d = torch.flip(torch.cumsum(d, dim=0), dims=[0])  # [parallel_len, m_k]

                # Zero the right-hand side for every timestep up to and
                # including the first one whose error exceeds the tolerance
                # (row 0 only when none does), so those rows get no correction.
                ind = torch.argmax((cur_error > tolerance).int()).item() + 1
                d[:ind] = 0

                solve_d = torch.linalg.solve(B, d)  # [parallel_len, m_k]

                A = sample_diff_mat + res_diff_mat  # [m_k, parallel_len, 1, seq_len, hidden_dim]
                Gf = torch.einsum("ijklm,ji->jklm", A, solve_d)  # [parallel_len, 1, seq_len, hidden_dim]

        # update x^k
        latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new - Gf

        # Stop criterion: residual norm of the last latent in the window.
        stop_criteria = cur_error[-1]

        # Approximate logits at this solver step.
        logits_list.append(logits)

    return torch.stack(logits_list, dim=0)


In [204]:
# Number of times the shared transformer block is applied.
Loop = 100


class Args:
    """Static hyper-parameter container handed to ``LoopedGPT`` and ``MyDataSet``."""

    # --- model architecture (read by LoopedGPT) ---
    dmodel = 256      # hidden width
    head = 4          # number of attention heads
    num_loop = Loop   # loop depth; also the number of residual gates
    drop = 0.0        # dropout probability
    rpe = False       # forwarded to Embedding / CausalSelfAttention as `rpe`
    vocab = 211       # vocabulary size (embedding rows and lm_head outputs)
    maxlen = 127      # forwarded to Embedding / CausalSelfAttention as `maxlen`

    # --- not read by the code visible in this notebook; presumably consumed
    # --- by MyDataSet or the training scripts (TODO confirm) ---
    num_layer = 1
    num_range = 180
    chain = False
    file = "../data/ED/60"
    debug = False

import torch
from torch import nn
import sys
sys.path.append("../")
from model import Embedding, CausalSelfAttention, NewGELU
from tasks.ED.dataloader import MyDataSet

class LoopedGPT(nn.Module):
    """GPT-style model that applies one shared attention + MLP block ``num_loop`` times.

    Every loop iteration ``i`` has its own scalar gate ``resweight[i]``
    (initialised to zero) that scales both the attention and the MLP residual
    branches of that iteration.
    """

    def __init__(self, args):
        super().__init__()
        self.num_loop = args.num_loop
        width = args.dmodel

        # Sub-modules are created in the same order as their keys are listed.
        shared_modules = {
            "embedding": Embedding(
                d_model=width,
                vocab_size=args.vocab,
                maxlen=args.maxlen,
                rpe=args.rpe,
            ),
            "drop": nn.Dropout(args.drop),
            "attn": CausalSelfAttention(
                d_model=width,
                nhead=args.head,
                drop=args.drop,
                maxlen=args.maxlen,
                rpe=args.rpe,
            ),
            "mlp": nn.Sequential(
                nn.Linear(width, 4 * width),
                NewGELU(),
                nn.Linear(4 * width, width),
            ),
            "norm1": nn.LayerNorm(width),
            "norm2": nn.LayerNorm(width),
            "ln_f": nn.LayerNorm(width),
        }
        self.transformer = nn.ModuleDict(shared_modules)
        self.lm_head = nn.Linear(width, args.vocab, bias=True)

        # One learnable scalar gate per loop iteration, starting at zero.
        gates = []
        for _ in range(self.num_loop):
            gates.append(nn.Parameter(torch.zeros(1)))
        self.resweight = nn.ParameterList(gates)

    def f_attn(self, x, idx):
        """Gated attention branch of loop iteration ``idx`` (pre-norm, no residual add)."""
        return self.resweight[idx] * self.transformer.attn(self.transformer.norm1(x))

    def f_mlp(self, x, idx):
        """Gated MLP branch of loop iteration ``idx`` (pre-norm, no residual add)."""
        return self.resweight[idx] * self.transformer.mlp(self.transformer.norm2(x))

    def f(self, x, idx):
        """One full loop iteration: attention residual followed by MLP residual."""
        after_attn = x + self.f_attn(x, idx)
        return after_attn + self.f_mlp(after_attn, idx)

    def forward(self, idx, ys=None, target_pos=None):
        """Run all loop iterations sequentially and return logits ``[b, t, vocab]``.

        For every iteration this prints the iteration number, the first 50
        features of the incoming latent and of the increment at ``target_pos``
        (batch element 0), and the intermediate arg-max prediction there.
        ``ys`` is accepted but not used.
        """
        hidden = self.transformer.drop(self.transformer.embedding(idx))
        for loop_id in range(self.num_loop):
            print(loop_id)
            before = hidden
            print(before[0, target_pos, :50])
            hidden = self.f(before, loop_id)
            print((hidden - before)[0, target_pos, :50])
            step_logits = self.lm_head(self.transformer.ln_f(hidden))
            print("pred", torch.argmax(step_logits[0, target_pos]))

        return self.lm_head(self.transformer.ln_f(hidden))  # [b, t, vocab]


In [207]:
# Build the dataset and model, then load the trained checkpoint on CPU.
args = Args()
dataset = MyDataSet(args, 1)

model = LoopedGPT(args)
# model = model.cuda()
ckpt = torch.load("../output/ED_60_Loop_100/latest.pt", map_location="cpu")
model.load_state_dict(ckpt, strict=True)
# Inference only: switch off train-time behaviour such as dropout
# (nn.Dropout is already a no-op here because Args.drop == 0.0).
model.eval()

import random

# randrange excludes len(dataset); random.randint(0, len(dataset)) could
# return len(dataset) itself, which is one past the last valid index.
i = random.randrange(len(dataset))
x, y, _ = dataset[i]
# x, y = x.cuda(), y.cuda()
x = x.unsqueeze(0)  # add the batch dimension -> [1, seq_len]
print(y)

# print(x, y)
# Position of the largest label value in y; used as the position to inspect.
target_pos = y.argmax()
# print(y.argmax())

# Sequential reference: run the full loop and compare with the label.
logits = model(x, target_pos=target_pos)
print(logits.shape)  # torch.Size([1, 126, 211])
truth = y[target_pos]
predict = torch.argmax(logits[0, target_pos])
print(truth, predict)

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 83,  2,  0,  0,  0,  0,  0])
0
tensor([-0.3198,  0.1604, -0.4218, -0.7188,  0.1371, -0.6879, -0.4626,  0.2135,
        -0.2972, -0.6037, -0.1988, -0.4068,  0.0749,  0.1185, -0.3799, -0.3385,
        -1.4910, -1.2082, -0.5119,  1.0380, -0.6013,  0.1269, -0.5872, -0.6043,
         0.2652, -0.0485,  0.7073, -1.3930, -0.5959,  0.1534, -0.1459, -1.2393,
        -0.8177,  0.6868, -0.1189,  0.1979,  0.5661,  0.4826, -0.7036,  0.3651,
        -0.5592, -0.9299, -0.3896, -0

In [208]:
# Solve the looped model with the parallel solver and compare its final
# prediction at target_pos with the label `truth` from the previous cell.
# NOTE(review): parallel (25) is smaller than Loop (100) and
# parallel_inference never advances its window (begin_idx / end_idx stay
# fixed), so these logits come from the latent after 25 loop iterations,
# not from the full 100-iteration model — confirm this is intended.
parallel = 25  # number of loop iterations solved jointly
memory_size = 10  # number of past iterates kept for Anderson acceleration
logits = parallel_inference(model, x, parallel, memory_size, max_steps=20, stop_idx=target_pos)
print(logits.shape)  # [num_steps, 1, seq_len, vocab]: one entry per solver step, num_steps <= max_steps
predict = torch.argmax(logits[-1, 0, target_pos])  # prediction after the last solver step
print(truth, predict)

step 0
0
tensor([-0.3198,  0.1604, -0.4218, -0.7188,  0.1371, -0.6879, -0.4626,  0.2135,
        -0.2972, -0.6037, -0.1988, -0.4068,  0.0749,  0.1185, -0.3799, -0.3385,
        -1.4910, -1.2082, -0.5119,  1.0380, -0.6013,  0.1269, -0.5872, -0.6043,
         0.2652, -0.0485,  0.7073, -1.3930, -0.5959,  0.1534, -0.1459, -1.2393,
        -0.8177,  0.6868, -0.1189,  0.1979,  0.5661,  0.4826, -0.7036,  0.3651,
        -0.5592, -0.9299, -0.3896, -0.5066, -1.0134,  0.6462, -0.8383, -0.8899,
         0.3035, -0.2724], requires_grad=True)
tensor([ 0.2687,  0.2183,  0.0329,  0.2841,  0.4149,  0.1796,  0.3410, -0.3069,
         0.2048,  0.1542, -0.0493, -0.2885,  0.0365,  0.3215, -0.1007,  0.3511,
        -0.0065,  0.1490,  0.3667, -0.2207,  0.0712, -0.1037,  0.1496,  0.3472,
        -0.0904,  0.3820, -0.2083,  0.3382, -0.6243, -0.0442, -0.1581,  0.3133,
        -0.1993, -0.0015, -0.1846, -0.0045, -0.2321,  0.0192,  0.0102, -0.1688,
        -0.2925,  0.4438, -0.3283, -0.2798,  0.7047, -0.3128,  0