Switch to task model/head factories instead of embedded if-else statements #1268

Merged
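For context before the diff: the core of this change is a decorator-driven registry. Each head class registers itself against one or more task types, and the factory resolves the head class at call time instead of branching through embedded if-else statements. Below is a minimal, self-contained sketch of that pattern; the names SimpleHeadFactory, ToyClassificationHead, and FakeTask are illustrative only and are not jiant code — the real implementation follows in the diff.

from typing import Callable, Dict, List, Type


class SimpleHeadFactory:
    """Hypothetical, minimal version of the registry pattern this PR introduces."""

    registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, task_types: List[str]) -> Callable:
        def inner_wrapper(wrapped_class):
            for task_type in task_types:
                assert task_type not in cls.registry
                cls.registry[task_type] = wrapped_class
            return wrapped_class

        return inner_wrapper

    def __call__(self, task, **kwargs):
        # Dispatch on the task's type instead of an if-else chain.
        return self.registry[task.TASK_TYPE](task, **kwargs)


@SimpleHeadFactory.register(["classification"])
class ToyClassificationHead:
    def __init__(self, task, hidden_size, **kwargs):
        self.task = task
        self.hidden_size = hidden_size


class FakeTask:
    TASK_TYPE = "classification"


head = SimpleHeadFactory()(FakeTask(), hidden_size=768)
print(type(head).__name__)  # ToyClassificationHead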
141 changes: 128 additions & 13 deletions jiant/proj/main/modeling/heads.py
@@ -1,29 +1,93 @@
from __future__ import annotations

import abc

import torch
import torch.nn as nn

import transformers

from jiant.ext.allennlp import SelfAttentiveSpanExtractor
from jiant.shared.model_resolution import ModelArchitectures
from jiant.tasks.core import TaskTypes
from typing import Callable
from typing import List

"""
In HuggingFace/others, these heads differ slightly across different encoder models.
We're going to abstract away from that and just choose one implementation.
"""


class JiantHeadFactory:
"""This factory is used to create task-specific heads for the supported Transformer encoders.

Attributes:
registry (dict): Dynamic registry mapping task types to task heads
"""

registry = {}

@classmethod
def register(cls, task_type_list: List[TaskTypes]) -> Callable:
"""Register each TaskType in task_type_list as a key mapping to a BaseHead task head

Args:
task_type_list (List[TaskTypes]): List of TaskTypes that are associated with a
BaseHead task head

Returns:
Callable: inner_wrapper() wrapping task head constructor or task head factory
"""

def inner_wrapper(wrapped_class: BaseHead) -> Callable:
"""Summary

Args:
wrapped_class (BaseHead): Task head class

Returns:
Callable: Task head constructor or factory
"""
for task_type in task_type_list:
assert task_type not in cls.registry
cls.registry[task_type] = wrapped_class
return wrapped_class

return inner_wrapper

def __call__(self, task, **kwargs) -> BaseHead:
"""Summary

Args:
task (Task): A task head will be created based on the task type
**kwargs: Arguments required for task head initialization

Returns:
BaseHead: Initialized task head
"""
head_class = self.registry[task.TASK_TYPE]
head = head_class(task, **kwargs)
return head


class BaseHead(nn.Module, metaclass=abc.ABCMeta):
pass
"""Absract class for task heads"""

@abc.abstractmethod
def __init__(self):
super().__init__()


@JiantHeadFactory.register([TaskTypes.CLASSIFICATION])
class ClassificationHead(BaseHead):
def __init__(self, hidden_size, hidden_dropout_prob, num_labels):
def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
"""From RobertaClassificationHead"""
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.out_proj = nn.Linear(hidden_size, num_labels)
self.num_labels = num_labels
self.out_proj = nn.Linear(hidden_size, task.num_labels)
self.num_labels = len(task.LABELS)

def forward(self, pooled):
x = self.dropout(pooled)
@@ -34,8 +98,9 @@ def forward(self, pooled):
return logits


@JiantHeadFactory.register([TaskTypes.REGRESSION, TaskTypes.MULTIPLE_CHOICE])
class RegressionHead(BaseHead):
def __init__(self, hidden_size, hidden_dropout_prob):
def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
"""From RobertaClassificationHead"""
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
@@ -51,12 +116,13 @@ def forward(self, pooled):
return scores


@JiantHeadFactory.register([TaskTypes.SPAN_COMPARISON_CLASSIFICATION])
class SpanComparisonHead(BaseHead):
def __init__(self, hidden_size, hidden_dropout_prob, num_spans, num_labels):
def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
"""From RobertaForSpanComparisonClassification"""
super().__init__()
self.num_spans = num_spans
self.num_labels = num_labels
self.num_spans = task.num_spans
self.num_labels = len(task.LABELS)
self.hidden_size = hidden_size
self.dropout = nn.Dropout(hidden_dropout_prob)
self.span_attention_extractor = SelfAttentiveSpanExtractor(hidden_size)
@@ -70,22 +136,24 @@ def forward(self, unpooled, spans):
return logits


@JiantHeadFactory.register([TaskTypes.TAGGING])
class TokenClassificationHead(BaseHead):
def __init__(self, hidden_size, num_labels, hidden_dropout_prob):
def __init__(self, task, hidden_size, hidden_dropout_prob, **kwargs):
"""From RobertaForTokenClassification"""
super().__init__()
self.num_labels = num_labels
self.num_labels = len(task.LABELS)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.classifier = nn.Linear(hidden_size, num_labels)
self.classifier = nn.Linear(hidden_size, self.num_labels)

def forward(self, unpooled):
unpooled = self.dropout(unpooled)
logits = self.classifier(unpooled)
return logits


@JiantHeadFactory.register([TaskTypes.SQUAD_STYLE_QA])
class QAHead(BaseHead):
def __init__(self, hidden_size):
def __init__(self, task, hidden_size, **kwargs):
"""From RobertaForQuestionAnswering"""
super().__init__()
self.qa_outputs = nn.Linear(hidden_size, 2)
@@ -98,10 +166,55 @@ def forward(self, unpooled):
return logits


@JiantHeadFactory.register([TaskTypes.MASKED_LANGUAGE_MODELING])
class JiantMLMHeadFactory:
"""This factory is used to create masked language modeling (MLM) task heads.
This is required due to Transformers implementing different MLM heads for
different encoders.

Attributes:
registry (dict): Dynamic registry mapping model architectures to MLM task heads
"""

registry = {}

@classmethod
def register(cls, model_arch_list: List[ModelArchitectures]) -> Callable:
"""Registers the ModelArchitectures in model_arch_list as keys mapping to a MLMHead

Args:
model_arch_list (List[ModelArchitectures]): List of ModelArchitectures mapping to
an MLM task head.

Returns:
Callable: inner_wrapper() wrapping the MLM head class
"""

def inner_wrapper(wrapped_class: BaseMLMHead) -> Callable:
for model_arch in model_arch_list:
assert model_arch not in cls.registry
cls.registry[model_arch] = wrapped_class
return wrapped_class

return inner_wrapper

def __call__(self, task, **kwargs):
"""Summary

Args:
task (Task): Task used to initialize task head
**kwargs: Additional arguments required to initialize task head
"""
mlm_head_class = self.registry[task.TASK_TYPE]
mlm_head = mlm_head_class(task, **kwargs)
return mlm_head


class BaseMLMHead(BaseHead, metaclass=abc.ABCMeta):
pass


@JiantMLMHeadFactory.register([ModelArchitectures.BERT])
class BertMLMHead(BaseMLMHead):
"""From BertOnlyMLMHead, BertLMPredictionHead, BertPredictionHeadTransform"""

@@ -128,6 +241,7 @@ def forward(self, unpooled):
return logits


@JiantMLMHeadFactory.register([ModelArchitectures.ROBERTA, ModelArchitectures.XLM_ROBERTA])
class RobertaMLMHead(BaseMLMHead):
"""From RobertaLMHead"""

@@ -155,7 +269,8 @@ def forward(self, unpooled):
return logits


class AlbertMLMHead(nn.Module):
@JiantMLMHeadFactory.register([ModelArchitectures.ALBERT])
class AlbertMLMHead(BaseMLMHead):
"""From AlbertMLMHead"""

def __init__(self, hidden_size, embedding_size, vocab_size, hidden_act="gelu"):
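A hedged sketch of how downstream code might consume the new factory in place of the old if-else dispatch; the call site and the keyword argument values below are assumptions based on the head constructors in this diff, not code from this PR:

from jiant.proj.main.modeling.heads import JiantHeadFactory

# Hypothetical call site. Previously this dispatch would have been an
# embedded if-else over task.TASK_TYPE; the factory now looks up the
# registered head class and instantiates it.
head_factory = JiantHeadFactory()
head = head_factory(
    task,                     # any Task whose TASK_TYPE is registered above
    hidden_size=768,          # assumed to match the encoder's hidden size
    hidden_dropout_prob=0.1,  # assumed value; normally taken from the encoder config
)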