In [10]:
from torch.profiler import profile, record_function, ProfilerActivity
import torch.nn as nn
import torch
from ALPackage.transformer import Transformer, TransformerAL
from utils import get_word_vector, get_nlp_data, set_device
import threading

In [11]:
class Args:
    """Hyperparameter bundle shared by the data loaders and models in this notebook.

    Parameters
    ----------
    dataset : str
        Dataset name (the whole ``Args`` object is handed to ``get_nlp_data``).
    max_len : int
        Maximum sequence length stored for the data pipeline.
    min_len : int, default 1
        Minimum sequence length stored for the data pipeline.
    **overrides
        Optional replacements for any default hyperparameter set below
        (e.g. ``batch_size=128``, ``n_layers=2``). Unknown names raise
        ``TypeError`` so a typo is not silently ignored.

    ``class_num`` and ``pretrained_embedding`` start as ``None``; the notebook
    fills them in after the dataset and vocabulary have been loaded.
    """

    def __init__(self, dataset, max_len, min_len=1, **overrides):
        self.dataset = dataset
        self.class_num = None            # set after get_nlp_data()
        self.batch_size = 256
        self.max_len = max_len
        self.min_len = min_len
        self.vocab_size = 30000
        self.pretrained_embedding = None  # set after get_word_vector()
        self.embedding_dim = 300
        self.x_hid = 256
        self.y_hid = 128
        self.n_heads = 6
        self.n_layers = 1
        self.dropout = 0.3
        self.lr = 0.00025
        self.epochs = 10
        self.act = nn.Tanh()

        # Apply caller-supplied overrides only to attributes that already exist,
        # so misspelled hyperparameters fail loudly instead of being ignored.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected hyperparameter {name!r}")
            setattr(self, name, value)

In [12]:
dataset = 'dbpedia_14'
args = Args(dataset=dataset, max_len=128)

# Load the data first: it reports the label count and the vocabulary,
# both of which are needed before a model can be built.
(train_loader, test_loader, class_num, vocab) = get_nlp_data(args)
args.class_num = class_num
args.pretrained_embedding = get_word_vector(vocab, 'glove')

# One fixed batch, reused by every profiling cell below.
x, y = next(iter(train_loader))



Found cached dataset dbpedia_14 (/home/u3933826/.cache/huggingface/datasets/dbpedia_14/dbpedia_14/2.0.0/01dab9e10d969eadcdbc918be5a09c9190a24caeae33b10eee8f367a1e3f1f0c)
Found cached dataset dbpedia_14 (/home/u3933826/.cache/huggingface/datasets/dbpedia_14/dbpedia_14/2.0.0/01dab9e10d969eadcdbc918be5a09c9190a24caeae33b10eee8f367a1e3f1f0c)


Original Data: 560000
Valid Data: 559967
total count words 887879
vocab size 30000


loading glove vocabs...: 100%|██████████| 400000/400000 [00:05<00:00, 70505.63it/s]


found 28354 words in glove


In [13]:
# Profile training steps of the baseline Transformer (trace saved with the `_BP` suffix).
# Each iteration opens a fresh profiler; the earlier iterations only serve as
# warm-up, so the trace is exported once, from the final iteration.
model = Transformer(args.vocab_size, args.embedding_dim, args.x_hid, args.class_num, args.n_heads, args.n_layers, args.dropout, args.pretrained_embedding)
# NOTE(review): "cuda:3" appears twice; presumably set_device maps one entry per
# sub-module of the model — confirm against utils.set_device.
set_device(model, ["cuda:0","cuda:1","cuda:2","cuda:3","cuda:3"])
optimizer = torch.optim.Adam(model.parameters(), args.lr)
loss_fn = nn.NLLLoss()
for _ in range(5):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with record_function("forward"):
            out = model(x)
        with record_function("backward"):
            # Put the labels on whichever device produced the output rather than
            # hard-coding the last GPU of the split.
            loss = loss_fn(out, y.to(out.device))
            loss.backward()
        with record_function("update"):
            optimizer.step()
            optimizer.zero_grad()

# Only the last iteration's trace was ever kept; write it once.
prof.export_chrome_trace(f"{dataset}_BP.json")

STAGE:2023-03-28 09:23:50 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:23:51 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:23:51 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:23:51 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:23:51 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:23:52 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:23:52 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:23:52 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:23:52 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:23:52 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:23:52 108:108 ActivityProfilerController.c

In [14]:
# Build the TransformerAL variant and split it across GPUs 0-3.
model = TransformerAL(
    args.vocab_size, args.embedding_dim, args.x_hid, args.class_num,
    args.y_hid, args.act, args.lr,
    args.n_heads, args.n_layers, args.dropout, args.pretrained_embedding,
)
set_device(model, [f'cuda:{i}' for i in range(4)])

In [15]:
# Profile the sequential AL schedule: forward, then model.backward(), then model.update().
# Earlier iterations act as warm-up; only the final iteration's trace is exported.
# The one-hot target is loop-invariant, so build it once outside the profiled region.
y_onehot = torch.nn.functional.one_hot(y, args.class_num).float()
for _ in range(5):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with record_function("forward"):
            model(x, y_onehot)
        with record_function("backward"):
            model.backward()
        with record_function("update"):
            model.update()
prof.export_chrome_trace(f"{dataset}_AL.json")

STAGE:2023-03-28 09:24:15 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:15 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:16 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:16 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:16 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:16 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:16 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:16 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:16 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:16 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:16 108:108 ActivityProfilerController.c

In [16]:
# Profile the AL_M schedule: forward, then model.record_thread_backward_and_update().
# Earlier iterations act as warm-up; only the final iteration's trace is exported.
# The one-hot target is loop-invariant, so build it once outside the profiled region.
y_onehot = torch.nn.functional.one_hot(y, args.class_num).float()
for _ in range(5):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with record_function("forward"):
            model(x, y_onehot)
        model.record_thread_backward_and_update()
prof.export_chrome_trace(f"{dataset}_AL_M.json")

STAGE:2023-03-28 09:24:21 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:21 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:21 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:21 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:21 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:22 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:22 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:22 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:22 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:22 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:22 108:108 ActivityProfilerController.c

In [17]:
# Profile the AL_P schedule: model.record_thread_forward_backward_and_update() per step.
# Earlier iterations act as warm-up; only the final iteration's trace is exported.
# The one-hot target is loop-invariant, so build it once outside the profiled region.
y_onehot = torch.nn.functional.one_hot(y, args.class_num).float()
for _ in range(5):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        model.record_thread_forward_backward_and_update(x, y_onehot)
prof.export_chrome_trace(f"{dataset}_AL_P.json")

STAGE:2023-03-28 09:24:25 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:25 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:25 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:25 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:25 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:26 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:26 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:26 108:108 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-03-28 09:24:26 108:108 output_json.cpp:417] Completed Stage: Post Processing
STAGE:2023-03-28 09:24:26 108:108 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-03-28 09:24:26 108:108 ActivityProfilerController.c

In [21]:
from tqdm import tqdm
import numpy as np
import time

In [27]:
def _sync_all_cuda_devices():
    """Block until every visible CUDA device has finished its queued work.

    ``torch.cuda.synchronize()`` with no argument waits only on the *current*
    device. The model here is split across several GPUs, so each device must be
    drained before the clock is read; otherwise trailing kernels on the other
    devices fall outside the measured interval.
    """
    for index in range(torch.cuda.device_count()):
        torch.cuda.synchronize(index)


def _time_epochs(run_step, loader, class_num, repeats):
    """Return wall-clock seconds for each of ``repeats`` full passes over ``loader``.

    ``run_step(inputs, y)`` performs one training step, where ``y`` is the float
    one-hot encoding of the batch labels with ``class_num`` columns.
    """
    durations = []
    for _ in range(repeats):
        _sync_all_cuda_devices()
        start = time.perf_counter()  # monotonic, high-resolution interval clock
        for inputs, labels in tqdm(loader):
            y = torch.nn.functional.one_hot(labels, class_num).float()
            run_step(inputs, y)
        _sync_all_cuda_devices()
        durations.append(time.perf_counter() - start)
    return durations


def timer(dataset, max_len, devices=None, repeats=10):
    """Benchmark three TransformerAL training schedules on ``dataset``.

    A fresh model is built and placed on ``devices``. Then ``repeats`` epochs
    are timed for each schedule, in this order (the same model keeps training
    across all three):

    * ``ALM``: forward, then ``model.thread_backward_and_update()``
    * ``AL``:  forward, then ``model.backward()`` and ``model.update()``
    * ``ALP``: ``model.thread_forward_backward_and_update(inputs, y)``

    Parameters
    ----------
    dataset : str
        Dataset name forwarded to ``Args``/``get_nlp_data``.
    max_len : int
        Maximum sequence length forwarded to ``Args``.
    devices : sequence of str, optional
        Devices passed to ``set_device``; defaults to ``cuda:0`` .. ``cuda:3``.
    repeats : int, default 10
        Number of timed epochs per schedule.

    Returns
    -------
    dict
        ``{"ALM": [...], "AL": [...], "ALP": [...]}``, with per-epoch seconds.
        The mean ± std for each schedule is also printed.
    """
    if devices is None:
        devices = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3']

    args = Args(dataset, max_len)
    train_loader, _test_loader, class_num, vocab = get_nlp_data(args)
    args.class_num = class_num
    args.pretrained_embedding = get_word_vector(vocab, 'glove')
    model = TransformerAL(args.vocab_size, args.embedding_dim, args.x_hid, args.class_num, args.y_hid, args.act, args.lr, args.n_heads, args.n_layers, args.dropout, args.pretrained_embedding)
    set_device(model, list(devices))

    def step_alm(inputs, y):
        model(inputs, y)
        model.thread_backward_and_update()

    def step_al(inputs, y):
        model(inputs, y)
        model.backward()
        model.update()

    # NOTE(review): there is no warm-up pass, so the first ALM epoch probably
    # absorbs one-time CUDA start-up cost. Also, if the thread_* methods return
    # before their worker threads finish, device synchronisation alone will not
    # capture that work — confirm they join.
    results = {
        "ALM": _time_epochs(step_alm, train_loader, args.class_num, repeats),
        "AL": _time_epochs(step_al, train_loader, args.class_num, repeats),
        "ALP": _time_epochs(model.thread_forward_backward_and_update, train_loader, args.class_num, repeats),
    }

    for tag in ("AL", "ALM", "ALP"):
        runs = results[tag]
        print(f"{dataset}_{tag}:{np.mean(runs):.4f} ± {np.std(runs):.4f}")
    return results

In [28]:
timer('ag_news', 128)



Found cached dataset ag_news (/home/u3933826/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
Found cached dataset ag_news (/home/u3933826/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


Original Data: 120000
Valid Data: 120000
total count words 102019
vocab size 30000


loading glove vocabs...: 100%|██████████| 400000/400000 [00:05<00:00, 71211.56it/s]


found 26754 words in glove


100%|██████████| 469/469 [00:15<00:00, 29.37it/s]
100%|██████████| 469/469 [00:15<00:00, 29.35it/s]
100%|██████████| 469/469 [00:16<00:00, 29.27it/s]
100%|██████████| 469/469 [00:16<00:00, 29.10it/s]
100%|██████████| 469/469 [00:15<00:00, 29.47it/s]
100%|██████████| 469/469 [00:15<00:00, 29.62it/s]
100%|██████████| 469/469 [00:16<00:00, 28.33it/s]
100%|██████████| 469/469 [00:17<00:00, 27.52it/s]
100%|██████████| 469/469 [00:16<00:00, 28.57it/s]
100%|██████████| 469/469 [00:15<00:00, 29.68it/s]
100%|██████████| 469/469 [00:20<00:00, 22.45it/s]
100%|██████████| 469/469 [00:20<00:00, 22.86it/s]
100%|██████████| 469/469 [00:20<00:00, 23.24it/s]
100%|██████████| 469/469 [00:20<00:00, 23.19it/s]
100%|██████████| 469/469 [00:20<00:00, 23.00it/s]
100%|██████████| 469/469 [00:21<00:00, 21.49it/s]
100%|██████████| 469/469 [00:20<00:00, 22.65it/s]
100%|██████████| 469/469 [00:20<00:00, 22.88it/s]
100%|██████████| 469/469 [00:20<00:00, 22.83it/s]
100%|██████████| 469/469 [00:20<00:00, 22.60it/s]


ag_news_AL:20.6552 ± 0.4432
ag_news_ALM:16.1679 ± 0.3712
ag_news_ALP:14.2381 ± 0.0741





In [29]:
timer('imdb', 128)

Original Data: 40000
Valid Data: 40000
total count words 193263
vocab size 30000


loading glove vocabs...: 100%|██████████| 400000/400000 [00:05<00:00, 70021.66it/s]


found 27875 words in glove


100%|██████████| 157/157 [00:09<00:00, 16.76it/s]
100%|██████████| 157/157 [00:09<00:00, 17.25it/s]
100%|██████████| 157/157 [00:09<00:00, 17.42it/s]
100%|██████████| 157/157 [00:09<00:00, 17.32it/s]
100%|██████████| 157/157 [00:09<00:00, 17.07it/s]
100%|██████████| 157/157 [00:09<00:00, 16.74it/s]
100%|██████████| 157/157 [00:09<00:00, 17.39it/s]
100%|██████████| 157/157 [00:09<00:00, 17.38it/s]
100%|██████████| 157/157 [00:09<00:00, 16.80it/s]
100%|██████████| 157/157 [00:09<00:00, 17.44it/s]
100%|██████████| 157/157 [00:10<00:00, 15.22it/s]
100%|██████████| 157/157 [00:09<00:00, 16.57it/s]
100%|██████████| 157/157 [00:09<00:00, 16.15it/s]
100%|██████████| 157/157 [00:10<00:00, 15.55it/s]
100%|██████████| 157/157 [00:09<00:00, 16.57it/s]
100%|██████████| 157/157 [00:10<00:00, 15.34it/s]
100%|██████████| 157/157 [00:10<00:00, 15.49it/s]
100%|██████████| 157/157 [00:09<00:00, 16.02it/s]
100%|██████████| 157/157 [00:09<00:00, 16.13it/s]
100%|██████████| 157/157 [00:09<00:00, 16.25it/s]


imdb_AL:9.8670 ± 0.2925
imdb_ALM:9.1547 ± 0.1477
imdb_ALP:9.9736 ± 0.2126





In [None]:
# Rebuild data and a fresh TransformerAL for the dbpedia_14 timing runs below.
dataset = 'dbpedia_14'
args = Args(dataset=dataset, max_len=128)
(train_loader, test_loader, class_num, vocab) = get_nlp_data(args)
args.class_num = class_num
args.pretrained_embedding = get_word_vector(vocab, 'glove')

model = TransformerAL(
    args.vocab_size, args.embedding_dim, args.x_hid, args.class_num,
    args.y_hid, args.act, args.lr,
    args.n_heads, args.n_layers, args.dropout, args.pretrained_embedding,
)
set_device(model, [f'cuda:{i}' for i in range(4)])

In [18]:
# Time 10 epochs of the ALM schedule: forward, then model.thread_backward_and_update().
# torch.cuda.synchronize() with no argument waits only on the current device, and
# the model spans several GPUs, so every visible device is drained before the
# clock is read. perf_counter is the monotonic clock intended for intervals.
alm = []
for _ in range(10):
    for dev in range(torch.cuda.device_count()):
        torch.cuda.synchronize(dev)
    start = time.perf_counter()
    for inputs, labels in tqdm(train_loader):
        y = torch.nn.functional.one_hot(labels, args.class_num).float()
        model(inputs, y)
        model.thread_backward_and_update()
    for dev in range(torch.cuda.device_count()):
        torch.cuda.synchronize(dev)
    alm.append(time.perf_counter() - start)

100%|██████████| 2188/2188 [01:22<00:00, 26.53it/s]
100%|██████████| 2188/2188 [01:20<00:00, 27.18it/s]
100%|██████████| 2188/2188 [01:21<00:00, 26.71it/s]
100%|██████████| 2188/2188 [01:18<00:00, 27.75it/s]
100%|██████████| 2188/2188 [01:19<00:00, 27.39it/s]
100%|██████████| 2188/2188 [01:23<00:00, 26.12it/s]
100%|██████████| 2188/2188 [01:20<00:00, 27.10it/s]
100%|██████████| 2188/2188 [01:22<00:00, 26.63it/s]
100%|██████████| 2188/2188 [01:19<00:00, 27.51it/s]
100%|██████████| 2188/2188 [01:21<00:00, 27.01it/s]


In [19]:
# Time 10 epochs of the sequential AL schedule: forward, model.backward(), model.update().
# torch.cuda.synchronize() with no argument waits only on the current device, and
# the model spans several GPUs, so every visible device is drained before the
# clock is read. perf_counter is the monotonic clock intended for intervals.
al = []
for _ in range(10):
    for dev in range(torch.cuda.device_count()):
        torch.cuda.synchronize(dev)
    start = time.perf_counter()
    for inputs, labels in tqdm(train_loader):
        y = torch.nn.functional.one_hot(labels, args.class_num).float()
        model(inputs, y)
        model.backward()
        model.update()
    for dev in range(torch.cuda.device_count()):
        torch.cuda.synchronize(dev)
    al.append(time.perf_counter() - start)

100%|██████████| 2188/2188 [01:38<00:00, 22.12it/s]
100%|██████████| 2188/2188 [01:49<00:00, 20.00it/s]
100%|██████████| 2188/2188 [01:42<00:00, 21.40it/s]
100%|██████████| 2188/2188 [01:34<00:00, 23.22it/s]
100%|██████████| 2188/2188 [01:33<00:00, 23.36it/s]
100%|██████████| 2188/2188 [01:32<00:00, 23.59it/s]
100%|██████████| 2188/2188 [01:46<00:00, 20.59it/s]
100%|██████████| 2188/2188 [01:42<00:00, 21.34it/s]
100%|██████████| 2188/2188 [01:46<00:00, 20.53it/s]
100%|██████████| 2188/2188 [01:49<00:00, 19.98it/s]


In [22]:
# Report mean ± std epoch time (seconds) for each schedule, one line per schedule.
print(
    f"{dataset}_AL:{np.mean(al):.4f} ± {np.std(al):.4f}",
    f"{dataset}_ALM:{np.mean(alm):.4f} ± {np.std(alm):.4f}",
    sep="\n",
)

dbpedia_14_AL:101.6177 ± 6.1217
dbpedia_14_ALM:81.0853 ± 1.4269
