Use jiant transformers model wrapper instead of if-else. Use taskmodel and head factory instead of if-else.
jeswan committed Jan 28, 2021
1 parent 723786a commit 3df2b38
Showing 10 changed files with 436 additions and 422 deletions.
144 changes: 131 additions & 13 deletions jiant/proj/main/modeling/heads.py
@@ -1,29 +1,95 @@
from __future__ import annotations

import abc

import torch
import torch.nn as nn

import transformers

from jiant.ext.allennlp import SelfAttentiveSpanExtractor
from jiant.shared.model_resolution import ModelArchitectures
from jiant.tasks.core import TaskTypes
from typing import Callable
from typing import List

"""
In HuggingFace/others, these heads differ slightly across different encoder models.
We're going to abstract away from that and just choose one implementation.
"""


class JiantHeadFactory:
    """This factory is used to create task-specific heads for the supported Transformer encoders.

    Attributes:
        registry (dict): Dynamic registry mapping task types to task heads
    """

    registry = {}

    @classmethod
    def register(cls, task_type_list: List[TaskTypes]) -> Callable:
        """Register each TaskType in task_type_list as a key mapping to a BaseHead task head

        Args:
            task_type_list (List[TaskTypes]): List of TaskTypes associated with a BaseHead task head

        Returns:
            Callable: inner_wrapper() wrapping task head constructor or task head factory
        """

        def inner_wrapper(wrapped_class: BaseHead) -> Callable:
            """Register wrapped_class as the head for each TaskType in task_type_list

            Args:
                wrapped_class (BaseHead): Task head class

            Returns:
                Callable: Task head constructor or factory
            """
            for task_type in task_type_list:
                assert task_type not in cls.registry
                cls.registry[task_type] = wrapped_class
            return wrapped_class

        return inner_wrapper

    def __call__(self, task, **kwargs) -> BaseHead:
        """Create the task head registered for the task's TASK_TYPE

        Args:
            task (Task): A task head will be created based on the task type
            **kwargs: Arguments required for task head initialization

        Returns:
            BaseHead: Initialized task head
        """
        head_class = self.registry[task.TASK_TYPE]
        head = head_class(task, **kwargs)
        return head


class BaseHead(nn.Module, metaclass=abc.ABCMeta):
    pass

    """Abstract class for task heads
    """

    @abc.abstractmethod
    def __init__(self):
        super().__init__()


@JiantHeadFactory.register([TaskTypes.CLASSIFICATION])
class ClassificationHead(BaseHead):
    def __init__(self, hidden_size, hidden_dropout_prob, num_labels):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaClassificationHead"""
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels
        self.out_proj = nn.Linear(hidden_size, task.num_labels)
        self.num_labels = len(task.LABELS)

    def forward(self, pooled):
        x = self.dropout(pooled)
@@ -34,8 +100,9 @@ def forward(self, pooled):
        return logits


@JiantHeadFactory.register([TaskTypes.REGRESSION, TaskTypes.MULTIPLE_CHOICE])
class RegressionHead(BaseHead):
    def __init__(self, hidden_size, hidden_dropout_prob):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaClassificationHead"""
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
@@ -51,12 +118,13 @@ def forward(self, pooled):
        return scores


@JiantHeadFactory.register([TaskTypes.SPAN_COMPARISON_CLASSIFICATION])
class SpanComparisonHead(BaseHead):
    def __init__(self, hidden_size, hidden_dropout_prob, num_spans, num_labels):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaForSpanComparisonClassification"""
        super().__init__()
        self.num_spans = num_spans
        self.num_labels = num_labels
        self.num_spans = task.num_spans
        self.num_labels = len(task.LABELS)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.span_attention_extractor = SelfAttentiveSpanExtractor(hidden_size)
@@ -70,22 +138,24 @@ def forward(self, unpooled, spans):
        return logits


@JiantHeadFactory.register([TaskTypes.TAGGING])
class TokenClassificationHead(BaseHead):
    def __init__(self, hidden_size, num_labels, hidden_dropout_prob):
    def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
        """From RobertaForTokenClassification"""
        super().__init__()
        self.num_labels = num_labels
        self.num_labels = len(task.LABELS)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(self, unpooled):
        unpooled = self.dropout(unpooled)
        logits = self.classifier(unpooled)
        return logits


@JiantHeadFactory.register([TaskTypes.SQUAD_STYLE_QA])
class QAHead(BaseHead):
    def __init__(self, hidden_size):
    def __init__(self, task, hidden_size, **kwargs):
        """From RobertaForQuestionAnswering"""
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)
@@ -98,10 +168,56 @@ def forward(self, unpooled):
        return logits


@JiantHeadFactory.register([TaskTypes.MASKED_LANGUAGE_MODELING])
class JiantMLMHeadFactory:
    """This factory is used to create masked language modeling (MLM) task heads.
    It is required because Transformers implements a different MLM head for each
    encoder.

    Attributes:
        registry (dict): Dynamic registry mapping model architectures to MLM task heads
    """

    registry = {}

    @classmethod
    def register(cls, model_arch_list: List[ModelArchitectures]) -> Callable:
        """Registers the ModelArchitectures in model_arch_list as keys mapping to an MLM head

        Args:
            model_arch_list (List[ModelArchitectures]): List of ModelArchitectures mapping to
                an MLM task head.

        Returns:
            Callable: MLMHead class
        """

        def inner_wrapper(wrapped_class: BaseMLMHead) -> Callable:
            for model_arch in model_arch_list:
                assert model_arch not in cls.registry
                cls.registry[model_arch] = wrapped_class
            return wrapped_class

        return inner_wrapper

    def __call__(self, task, **kwargs):
        """Create the MLM head for the given task

        Args:
            task (Task): Task used to initialize task head
            **kwargs: Additional arguments required to initialize task head
        """
        mlm_head_class = self.registry[task.TASK_TYPE]
        mlm_head = mlm_head_class(task, **kwargs)
        return mlm_head


class BaseMLMHead(BaseHead, metaclass=abc.ABCMeta):
    pass


@JiantMLMHeadFactory.register([ModelArchitectures.BERT])
class BertMLMHead(BaseMLMHead):
    """From BertOnlyMLMHead, BertLMPredictionHead, BertPredictionHeadTransform"""

@@ -126,6 +242,7 @@ def forward(self, unpooled):
        return logits


@JiantMLMHeadFactory.register([ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA])
class RobertaMLMHead(BaseMLMHead):
    """From RobertaLMHead"""

@@ -151,7 +268,8 @@ def forward(self, unpooled):
        return logits


class AlbertMLMHead(nn.Module):
@JiantMLMHeadFactory.register([ModelArchitectures.ALBERT])
class AlbertMLMHead(BaseMLMHead):
    """From AlbertMLMHead"""

    def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"):

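For context, a minimal usage sketch of the new registry-based dispatch that replaces the old if-else chains (not part of this commit): the DummyClassificationTask below, with its TASK_TYPE, LABELS, and num_labels attributes, is a hypothetical stand-in for a real task from jiant.tasks, and the sizes are arbitrary.

import torch

from jiant.proj.main.modeling.heads import JiantHeadFactory
from jiant.tasks.core import TaskTypes


class DummyClassificationTask:
    """Hypothetical stand-in for a real jiant classification task."""

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["entailment", "not_entailment"]
    num_labels = 2


# The factory resolves the head class from task.TASK_TYPE via its registry,
# so the call site needs no per-task if-else chain.
head = JiantHeadFactory()(
    DummyClassificationTask(), hidden_size=768, hidden_dropout_prob=0.1
)

pooled = torch.randn(4, 768)  # [batch, hidden_size], e.g. a [CLS] embedding
logits = head(pooled)  # -> torch.Size([4, 2])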
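
Both registries are plain class-level dicts populated at import time by the decorators above: JiantHeadFactory.registry is keyed by TaskTypes, while JiantMLMHeadFactory.registry is keyed by ModelArchitectures, since each encoder family ships its own MLM head. A sketch for inspecting what gets registered; the comments show the pairs expected from the decorators in this file.

from jiant.proj.main.modeling.heads import JiantHeadFactory, JiantMLMHeadFactory

# Task-type registry: filled by the @JiantHeadFactory.register(...) decorators.
for task_type, head_cls in JiantHeadFactory.registry.items():
    print(task_type, "->", head_cls.__name__)
# TaskTypes.CLASSIFICATION -> ClassificationHead
# TaskTypes.REGRESSION and TaskTypes.MULTIPLE_CHOICE -> RegressionHead
# TaskTypes.SPAN_COMPARISON_CLASSIFICATION -> SpanComparisonHead
# TaskTypes.TAGGING -> TokenClassificationHead
# TaskTypes.SQUAD_STYLE_QA -> QAHead
# TaskTypes.MASKED_LANGUAGE_MODELING -> JiantMLMHeadFactory

# Model-architecture registry: filled by @JiantMLMHeadFactory.register(...).
for model_arch, mlm_cls in JiantMLMHeadFactory.registry.items():
    print(model_arch, "->", mlm_cls.__name__)
# ModelArchitectures.BERT -> BertMLMHead
# ModelArchitectures.ROBERTA and ModelArchitectures.XLM_ROBERTA -> RobertaMLMHead
# ModelArchitectures.ALBERT -> AlbertMLMHead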