Commit: [unit test] refactor test tensor (#1005)
* polish test_gpt
* update op unit tests
* update test model
Showing 8 changed files with 143 additions and 290 deletions.
@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.utils.cuda import get_current_device

from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator


class DummyDataLoader(DummyDataGenerator):
    # Synthetic GPT-2 batches: random token ids with an all-ones attention mask.
    vocab_size = 50304
    batch_size = 4
    seq_len = 1024

    def generate(self):
        input_ids = torch.randint(0,
                                  DummyDataLoader.vocab_size,
                                  (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
                                  device=get_current_device())
        attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask


class GPTLMModel(nn.Module):
    # Thin wrapper around HuggingFace's GPT2LMHeadModel with all dropout
    # disabled, so unit tests see deterministic activations.

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50304,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size,
                       resid_pdrop=0.0,
                       embd_pdrop=0.0,
                       attn_pdrop=0.0))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits; the KV cache is turned off when gradient
        # checkpointing is on, since the two features are incompatible.
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_s(checkpoint=True):
    return GPTLMModel(checkpoint=checkpoint)


def gpt2_m(checkpoint=True):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


class GPTLMLoss(nn.Module):
    # Standard causal LM loss: logits at position t are scored against the
    # token at position t + 1.

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


@non_distributed_component_funcs.register(name='gpt2')
def get_training_components():
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    criterion = GPTLMLoss()
    return gpt2_s, trainloader, testloader, torch.optim.Adam, criterion
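
For context, here is a minimal sketch of how a registered component tuple like this is typically consumed in a unit test. It assumes the registry exposes a get_callable accessor, that DummyDataGenerator instances are iterable, and that the package lives at tests.components_to_test; none of these appear in this diff, so treat them as assumptions rather than the commit's own API.

# Minimal usage sketch, not part of the commit. Assumed names:
# get_callable(), the tests.components_to_test path, and iterability
# of DummyDataLoader.
import torch
from tests.components_to_test.registry import non_distributed_component_funcs

def run_one_gpt2_step():
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, trainloader, testloader, optim_class, criterion = get_components_func()

    model = model_builder(checkpoint=True).cuda()
    optimizer = optim_class(model.parameters(), lr=1e-3)

    input_ids, attn_mask = next(iter(trainloader))
    logits = model(input_ids, attn_mask)
    # For causal LM training the labels are the inputs themselves;
    # GPTLMLoss shifts them by one position internally.
    loss = criterion(logits, input_ids)
    loss.backward()
    optimizer.step()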