Skip to content

Commit

Permalink
implement abstract methods for JiantTransformerModels
Browse files Browse the repository at this point in the history
  • Loading branch information
jeswan committed Mar 19, 2021
1 parent 33faa25 commit ad49c64
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions jiant/proj/main/modeling/primary.py
Expand Up @@ -423,6 +423,13 @@ def get_feat_spec(self, max_seq_length):
sep_token_extra=False,
)

@classmethod
def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
raise NotImplementedError()

def get_mlm_weights_dict(self, weights_dict):
raise NotImplementedError()


@JiantTransformersModelFactory.register(ModelArchitectures.BART)
class JiantBartModel(JiantTransformersModel):
Expand Down Expand Up @@ -475,12 +482,19 @@ def __call__(self, encoder, input_ids, input_mask):
pooled = unpooled[batch_idx, slen - input_ids.eq(encoder.config.pad_token_id).sum(1) - 1]
return JiantModelOutput(pooled=pooled, unpooled=unpooled, other=other)

def get_mlm_weights_dict(self, weights_dict):
raise NotImplementedError()


@JiantTransformersModelFactory.register(ModelArchitectures.MBART)
class JiantMBartModel(JiantBartModel):
def __init__(self, baseObject):
super().__init__(baseObject)

@classmethod
def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
raise NotImplementedError()

def get_feat_spec(self, max_seq_length):
# mBART is weird
# token 0 = '<s>' which is the cls_token
Expand All @@ -498,3 +512,6 @@ def get_feat_spec(self, max_seq_length):
sequence_b_segment_id=0, # mBART has no token_type_ids
sep_token_extra=True,
)

def get_mlm_weights_dict(self, weights_dict):
raise NotImplementedError()

0 comments on commit ad49c64

Please sign in to comment.