In [61]:
import io
from typing import List, Union

import pandas as pd
import numpy as np
from PIL import Image

from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms as T

from google.cloud import storage

import logging

# Pathologies classified in the original CheXpert paper.
pathologies = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion',
]

# Strategies for handling uncertain (-1) labels compared in the CheXpert paper.
uncertainty_policies = [
    'U-Ignore', 'U-Zeros', 'U-Ones', 'U-SelfTrained', 'U-MultiClass',
]


# #####################
# # Create a Dataset ##
# #####################
class CheXpertDataset(Dataset):
    """CheXpert image/label dataset with a selectable uncertainty policy.

    Label rows come from ``CheXpert-v1.0/train.csv`` or
    ``CheXpert-v1.0/valid.csv``. They are read from a local copy under
    ``data_path`` or, if that fails, from the ``chexpert_database_stanford``
    GCS bucket. In the bucket case, images are also downloaded from the
    bucket in ``__getitem__``.
    """

    def __init__(self,
                 data_path: Union[str, None] = None,
                 uncertainty_policy: str = 'U-Ones',
                 logger: logging.Logger = logging.getLogger(__name__),
                 pathologies: List[str] = pathologies,
                 train: bool = True,
                 resize_shape: tuple = (256, 256)) -> None:
        """Initialize the dataset and preprocess labels per uncertainty policy.

        Args:
            data_path (str, optional): Prefix (ending in a path separator)
                under which ``CheXpert-v1.0/`` lives. It is concatenated
                as-is with the csv path and with every image path in the
                csv. If the csv cannot be read from here, the GCS bucket is
                tried instead.
            uncertainty_policy (str): How uncertain (-1) labels are handled.
                - 'U-Ignore': keep -1; the loss is expected to mask it.
                - 'U-Zeros': map -1 to 0.
                - 'U-Ones': map -1 to 1.
                - 'U-SelfTrained': use the labels as-is; they must
                  already be free of -1.
                - 'U-MultiClass': one-hot encode each pathology over
                  (0, 1, other) into 3 columns.
            logger (logging.Logger): Logger for dataset events.
            pathologies (List[str], optional): Label columns to keep.
                Defaults to 'Atelectasis', 'Cardiomegaly', 'Consolidation',
                'Edema', and 'Pleural Effusion'.
            train (bool): If True, load the training split; otherwise load
                the validation (dev) split defined by the CheXpert authors.
            resize_shape (tuple): Target size passed to ``T.Resize``.

        Raises:
            ValueError: If ``uncertainty_policy`` is not a known policy.
            RuntimeError: If the csv can be read neither locally nor from
                the GCS bucket.
        """
        if uncertainty_policy not in uncertainty_policies:
            logger.error(
                "Unknown uncertainty policy. Known policies: " +
                f"{uncertainty_policies}")
            # Fail fast: returning from __init__ would leave an object with
            # no image_names/labels/transform that only breaks later.
            raise ValueError(
                f"Unknown uncertainty policy {uncertainty_policy!r}; "
                f"expected one of {uncertainty_policies}")

        split = 'train' if train else 'valid'
        csv_path = f"CheXpert-v1.0/{split}.csv"
        path = str(data_path) + csv_path

        self.in_cloud = False
        self.bucket = None

        try:
            data = pd.read_csv(path)
            data['Path'] = data_path + data['Path']
            logger.info("Local database found.")
        except Exception as e:
            logger.warning(f"Couldn't read csv at path {path}.\n{e}")
            data = self._read_csv_from_bucket(csv_path, logger)

        # Keep only the requested label columns; blank labels count as 0.
        data = data.set_index('Path').loc[:, pathologies].fillna(0)

        self.image_names = data.index.to_numpy()
        self.labels = self._encode_labels(data, uncertainty_policy, logger)
        self.transform = T.Compose([
            T.Resize(resize_shape),
            T.ToTensor(),
            # Single-channel mean/std; torchvision broadcasts it over the
            # three RGB channels.
            T.Normalize(mean=[0.5330], std=[0.0349]),
        ])

    def _read_csv_from_bucket(self, csv_path: str,
                              logger: logging.Logger) -> pd.DataFrame:
        """Download ``csv_path`` from the CheXpert GCS bucket and parse it.

        On success, sets ``self.bucket`` and ``self.in_cloud`` so that
        images are later fetched from the same bucket.

        Raises:
            RuntimeError: If the bucket or the blob cannot be read.
        """
        try:
            storage_client = storage.Client(project='labshurb')
            self.bucket = storage_client.bucket('chexpert_database_stanford')

            blob = self.bucket.get_blob(csv_path)
            if blob is None:
                raise FileNotFoundError(f"{csv_path} not found in bucket")
            # Parse in memory rather than leaving a tmp.csv in the CWD.
            data = pd.read_csv(io.BytesIO(blob.download_as_bytes()))
        except Exception as e_:
            logger.error(
                f"Couldn't reach file {csv_path} in the cloud bucket.\n{e_}")
            raise RuntimeError(
                f"CheXpert csv {csv_path!r} is unavailable locally and in "
                "the cloud bucket") from e_

        self.in_cloud = True
        logger.info("Cloud database found.")
        return data

    @staticmethod
    def _encode_labels(data: pd.DataFrame, uncertainty_policy: str,
                       logger: logging.Logger) -> np.ndarray:
        """Return a float64 label matrix with one row per image.

        The width is ``len(data.columns)``, or three times that for
        'U-MultiClass' (three one-hot columns per pathology, in order).
        """
        values = data.to_numpy(dtype=np.float64)

        if uncertainty_policy == 'U-Ignore':
            # Labels are left untouched; the -1 entries are masked out in
            # the loss computation instead.
            pass
        elif uncertainty_policy == 'U-Zeros':
            values[values == -1] = 0
        elif uncertainty_policy == 'U-Ones':
            values[values == -1] = 1
        elif uncertainty_policy == 'U-SelfTrained':
            # The labels are expected to have been relabelled already, so
            # they are used as-is.
            logger.warning(
                f"Using {uncertainty_policy} uncertainty policy, " +
                "make sure there are no uncertainty labels in the dataset.")
        elif uncertainty_policy == 'U-MultiClass':
            # 0 -> [1, 0, 0], 1 -> [0, 1, 0], anything else (-1) -> [0, 0, 1]
            codes = np.where(values == 0, 0, np.where(values == 1, 1, 2))
            # Explicit width so an empty split still reshapes cleanly.
            values = np.eye(3)[codes].reshape(
                values.shape[0], 3 * values.shape[1])

        return values

    def __getitem__(self, index: int) -> dict:
        """Return the transformed image and its label vector.

        Args:
            index (int): Index of the sample in the dataset.

        Returns:
            dict: ``"pixel_values"`` is the transformed 3-channel image
            tensor. ``"labels"`` is the float32 numpy label vector.
        """
        name = self.image_names[index]
        if self.in_cloud:
            source = io.BytesIO(self.bucket.blob(name).download_as_bytes())
        else:
            source = name

        # convert() loads the pixels into a new image, so the source file
        # can be closed right away instead of waiting for GC.
        with Image.open(source) as raw:
            img = raw.convert('RGB')

        return {"pixel_values": self.transform(img),
                "labels": self.labels[index].astype(np.float32)}

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.image_names)

In [62]:
from transformers import Trainer
from torch import nn
from torch import masked_select


class MaskedLossTrainer(Trainer):
    """Trainer whose BCE loss ignores uncertain labels (U-Ignore policy).

    Label entries equal to -1 are excluded. The remaining entries feed a
    mean binary cross-entropy on the logits.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the masked BCE-with-logits loss for one batch.

        Args:
            model: The model being trained (possibly wrapped by Trainer).
            inputs (dict): Batch containing the model inputs and ``labels``.
            return_outputs (bool): Also return the model outputs.
            **kwargs: Extra arguments passed by newer ``transformers``
                releases (e.g. ``num_items_in_batch``); unused here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        # Pop the labels so the model does not also compute its own
        # (unmasked) loss against targets that still contain -1.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        num_labels = self.model.config.num_labels
        logits = logits.view(-1, num_labels)
        labels = labels.view(-1, num_labels)
        mask = labels > -1.

        # BCEWithLogitsLoss has no parameters/buffers, so it needs no device
        # (and its constructor has no `device` argument).
        criterion = nn.BCEWithLogitsLoss()
        if mask.any():
            loss = criterion(masked_select(logits, mask),
                             masked_select(labels, mask))
        else:
            # Every label in the batch is uncertain: the mean over an empty
            # selection would be NaN, so contribute a graph-connected zero.
            loss = logits.sum() * 0.0
        return (loss, outputs) if return_outputs else loss


class MultiOutputTrainer(Trainer):
    """Trainer for the U-MultiClass policy.

    The logits are read as consecutive groups of ``num_classes_per_label``
    scores, one group per pathology. The loss is the cross-entropy of each
    group against its one-hot label block, averaged over the groups.
    """

    # One-hot width per pathology used by the U-MultiClass labels.
    num_classes_per_label = 3

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the per-pathology cross-entropy averaged over pathologies.

        Args:
            model: The model being trained (possibly wrapped by Trainer).
            inputs (dict): Batch containing the model inputs and ``labels``.
            return_outputs (bool): Also return the model outputs.
            **kwargs: Extra arguments passed by newer ``transformers``
                releases (e.g. ``num_items_in_batch``); unused here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.

        Raises:
            ValueError: If the logit width is not a multiple of
                ``num_classes_per_label``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        k = self.num_classes_per_label
        if logits.shape[-1] % k:
            raise ValueError(
                f"logit width {logits.shape[-1]} is not a multiple of {k}")

        # (batch, n_pathologies * k) -> (batch * n_pathologies, k). Every
        # group has the same number of rows, so the mean over all rows
        # equals the average of the per-pathology mean cross-entropies.
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.reshape(-1, k), labels.reshape(-1, k))
        return (loss, outputs) if return_outputs else loss

In [82]:
import gc
import os
import argparse
from datetime import datetime

import torch
from torchmetrics.classification import (
    MultilabelAUROC,
    MultilabelF1Score,
    MultilabelAccuracy,
    MulticlassAUROC,
    MulticlassF1Score,
    MulticlassAccuracy
)

#from chexpert import CheXpertDataset
#from custom_trainer import MaskedLossTrainer, MultiOutputTrainer

from transformers import (
    ViTForImageClassification,
    TrainingArguments,
    Trainer
)

import wandb

import logging
log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=log_fmt, level=logging.INFO)

gc.collect()

# Strategies for handling uncertain (-1) labels compared in the CheXpert paper.
uncertainty_policies = [
    'U-Ignore', 'U-Zeros', 'U-Ones', 'U-SelfTrained', 'U-MultiClass',
]

# Prefer the GPU when one is visible, releasing its cached blocks first.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    device = 'cuda'
else:
    device = 'cpu'


def get_args(argv=None):
    """Parse the training command-line arguments.

    Args:
        argv (list of str, optional): Arguments to parse instead of
            ``sys.argv[1:]``. Pass ``[]`` from a notebook to get the
            defaults without touching ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser("train_vit.py")
    parser.add_argument(
        "--epochs",
        "-e",
        required=False,
        type=int,
        default=5,
        help="Epochs of training"
    )
    parser.add_argument(
        "--learning_rate",
        "-l",
        required=False,
        type=float,
        default=4e-4,
        help="learning rate of training"
    )
    parser.add_argument(
        "--gradient_accumulation",
        "-g",
        required=False,
        type=int,
        default=64,
        help="Gradient accumulation steps"
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        required=False,
        type=int,
        default=4,
        help="Batch size"
    )
    parser.add_argument(
        '--job_dir',
        '-j',
        required=False,
        type=str,
        default='.',
        help='Bucket to store saved model, include gs://')
    parser.add_argument(
        '--data_path',
        '-d',
        required=False,
        type=str,
        default=r"C:/Users/hurbl/OneDrive/Área de Trabalho/Loon Factory/repository/Chest-X-Ray-Pathology-Classifier/data/raw/",
        # default="gcs://chexpert_database_stanford/",
        help='Local or storage path to csv metadata file'
    )
    parser.add_argument(
        '--uncertainty_policy',
        '-u',
        required=False,
        type=str,
        # Reject typos at parse time instead of inside the dataset.
        choices=uncertainty_policies,
        default=uncertainty_policies[-1],
        help='Uncertainty policy'
    )
    parser.add_argument(
        '--resize',
        '-r',
        required=False,
        # `type=tuple` would split a single string into characters
        # ("224" -> ('2', '2', '4')); take two integers instead.
        type=int,
        nargs=2,
        metavar=('HEIGHT', 'WIDTH'),
        default=(224, 224),
        help='Resize dimensions (height width)'
    )
    parser.add_argument(
        '--checkpoint',
        '-c',
        required=False,
        type=str,
        default='google/vit-base-patch16-224',
        help='checkpoint to load from hugging face hub'
    )
    return parser.parse_args(argv)


# Multi-label metrics: one independent binary output for each of 5 pathologies.
AUC = MultilabelAUROC(
    num_labels=5, thresholds=None, average='macro',
).to(device)
F1 = MultilabelF1Score(
    average='macro', num_labels=5,
).to(device)
ACC = MultilabelAccuracy(
    average='macro', num_labels=5,
).to(device)

# Multiclass metrics over 3 classes, applied to one pathology's group of
# three U-MultiClass outputs at a time.
multiclassAUC = MulticlassAUROC(
    num_classes=3, thresholds=None, average='macro',
).to(device)
multiclassF1 = MulticlassF1Score(
    average='macro', num_classes=3,
).to(device)
multiclassACC = MulticlassAccuracy(
    average='macro', num_classes=3,
).to(device)

def compute_metrics(eval_pred):
    """Compute macro AUROC, F1 and accuracy for the Trainer's eval loop.

    Two label layouts are handled:

    * U-MultiClass (15 label columns): consecutive groups of three one-hot
      columns, one group per pathology. Each group is scored with the
      multiclass metrics against its argmax class, and the five scores are
      averaged.
    * Otherwise (one 0/1 column per pathology): the multilabel metrics are
      applied to the logits directly.

    Args:
        eval_pred: ``(logits, labels)`` numpy arrays supplied by Trainer.

    Returns:
        dict: ``auc_roc_mean``, ``f1_mean`` and ``acc_mean`` as floats.
    """
    logits, labels = eval_pred
    logits = torch.from_numpy(logits).to(device)
    labels = torch.from_numpy(labels).to(device).long()

    group = 3  # one-hot width per pathology in the U-MultiClass layout
    if labels.shape[1] == 5 * group:
        aucs, f1s, accs = [], [], []
        for start in range(0, labels.shape[1], group):
            head_logits = logits[:, start:start + group]
            head_target = torch.argmax(
                labels[:, start:start + group], dim=1).int()
            aucs.append(multiclassAUC(head_logits, head_target))
            f1s.append(multiclassF1(head_logits, head_target))
            accs.append(multiclassACC(head_logits, head_target))
        auc = torch.stack(aucs).mean()
        f1 = torch.stack(f1s).mean()
        acc = torch.stack(accs).mean()
    else:
        auc = AUC(logits, labels)
        f1 = F1(logits, labels)
        acc = ACC(logits, labels)

    # Calling a torchmetrics metric also accumulates the batch into its
    # global state (the AUROC ones keep every prediction). Only the returned
    # batch values are used here, so reset to stop memory growing across
    # evaluations.
    for metric in (AUC, F1, ACC, multiclassAUC, multiclassF1, multiclassACC):
        metric.reset()

    return {
        'auc_roc_mean': auc.cpu().mean().item(),
        'f1_mean': f1.cpu().mean().item(),
        'acc_mean': acc.cpu().item()
    }


def main(args):
    """Fine-tune a ViT on CheXpert under one uncertainty policy.

    Builds the train/valid datasets and a ViT classifier (15 outputs for
    U-MultiClass, otherwise 5). It then picks a Trainer whose loss matches
    the policy, trains, evaluates, and logs to Weights & Biases.

    Args:
        args (argparse.Namespace): Parsed CLI arguments (see ``get_args``),
            also recorded as the W&B run config.
    """
    with wandb.init(project="chexpert-vit", job_type="train", config=args,
                    name=str(args.uncertainty_policy)+str(datetime.now().strftime("%d%m%Y_%H%M%S")),
                    tags=[
                        args.uncertainty_policy,
                        args.checkpoint]) as run:
        config = run.config
        policy = config['uncertainty_policy']

        train_dataset = CheXpertDataset(
            data_path=config['data_path'],
            uncertainty_policy=policy,
            train=True,
            resize_shape=config['resize'])

        val_dataset = CheXpertDataset(
            data_path=config['data_path'],
            uncertainty_policy=policy,
            train=False,
            resize_shape=config['resize'])

        num_labels = 15 if policy == 'U-MultiClass' else 5

        model = ViTForImageClassification.from_pretrained(
            config['checkpoint'],
            problem_type="multi_label_classification",
            num_labels=num_labels,
            ignore_mismatched_sizes=True
        ).to(device)

        training_args = TrainingArguments(
                output_dir=f"./output/25092023/{config['checkpoint']}/{config['uncertainty_policy']}",
                report_to='wandb',  # Turn on Weights & Biases logging
                save_strategy='steps',
                save_steps=0.05,
                evaluation_strategy="epoch",
                logging_strategy='steps',
                logging_steps=1,
                optim='adamw_torch',
                num_train_epochs=config['epochs'],
                learning_rate=config['learning_rate'],
                lr_scheduler_type='linear',
                warmup_steps=1_000,
                max_grad_norm=1.0,
                per_device_train_batch_size=config['batch_size'],
                gradient_accumulation_steps=config['gradient_accumulation'],
                weight_decay=0.1,
                # gradient_checkpointing=True,
                auto_find_batch_size=False,
                # fp16 mixed precision needs CUDA; transformers refuses it
                # on CPU-only machines, so enable it only with a GPU.
                fp16=(device == 'cuda'),
                dataloader_drop_last=True,
                #load_best_model_at_end=True,
                push_to_hub=True,
                hub_strategy='checkpoint',
                hub_private_repo=False,
                hub_model_id=f"lucascruz/CheXpert-ViT-{config['uncertainty_policy']}",
            )

        # Policy-specific losses: masked BCE for U-Ignore, per-pathology
        # cross-entropy for U-MultiClass; other policies use the plain
        # Trainer (and thus the model's own loss).
        trainer_cls = {
            'U-Ignore': MaskedLossTrainer,
            'U-MultiClass': MultiOutputTrainer,
        }.get(policy, Trainer)
        trainer = trainer_cls(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics,
        )

        train_results = trainer.train()
        # trainer.save_model(f'{config["job_dir"]}/{config["uncertainty_policy"]}/model_output')

        trainer.log_metrics("train", train_results.metrics)
        trainer.save_metrics("train", train_results.metrics)
        trainer.save_state()

        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

  def wrap_fallback() -> None:
 20%|██        | 1/5 [19:03<1:16:13, 1143.46s/it]


In [83]:
if __name__ == "__main__":
    import sys

    # Point W&B at this project and turn on its model logging.
    project_name = "chexpert-vit"
    os.environ.update(WANDB_PROJECT=project_name, WANDB_LOG_MODEL="true")

    # Discard any inherited argv (e.g. a Jupyter kernel's) so argparse
    # falls back to the defaults.
    sys.argv = ['']
    args = get_args()
    main(args)

2023-10-07 09:32:31,358 - __main__ - INFO - Local database found.
2023-10-07 09:32:31,742 - __main__ - INFO - Local database found.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([15, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([15]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
c:\Users\hurbl\OneDrive\Área de Trabalho\Loon Factory\repository\Chest-X-Ray-Pathology-Classifier\notebooks\./output/25092023/google/vit-base-patch16-224/U-MultiClass is already a clone of https://huggingface.co/lucascruz/CheXpert-ViT-U-MultiClass. Make sure you pull the latest changes with `repo.git_pull()`.
                        

{'loss': 0.0396, 'learning_rate': 4.0000000000000003e-07, 'epoch': 1.0}



                                             


  def wrap_fallback() -> None:
[A[A[A                                     

 20%|██        | 1/5 [00:15<00:02,  1.71it/s]

[A[A
[A

{'eval_loss': 1.3038330078125, 'eval_auc_roc_mean': 0.31190475821495056, 'eval_f1_mean': 0.2030014544725418, 'eval_acc_mean': 0.2825396955013275, 'eval_runtime': 0.5405, 'eval_samples_per_second': 18.501, 'eval_steps_per_second': 3.7, 'epoch': 1.0}


                                             


[A[A[A                                     

 40%|████      | 2/5 [00:15<00:27,  9.08s/it]

[A[A
[A

{'loss': 0.039, 'learning_rate': 8.000000000000001e-07, 'epoch': 2.0}



                                             


  def wrap_fallback() -> None:
[A[A[A                                     

 40%|████      | 2/5 [00:16<00:27,  9.08s/it]

[A[A
[A

{'eval_loss': 1.299218773841858, 'eval_auc_roc_mean': 0.31190475821495056, 'eval_f1_mean': 0.2030014544725418, 'eval_acc_mean': 0.2825396955013275, 'eval_runtime': 0.3271, 'eval_samples_per_second': 30.572, 'eval_steps_per_second': 6.114, 'epoch': 2.0}


                                             


[A[A[A                                     

 60%|██████    | 3/5 [00:17<00:11,  5.69s/it]

[A[A
[A

{'loss': 0.0399, 'learning_rate': 1.2000000000000002e-06, 'epoch': 3.0}



                                             


  def wrap_fallback() -> None:
[A[A[A                                     

 60%|██████    | 3/5 [01:01<00:11,  5.69s/it]

[A[A
[A

{'eval_loss': 1.2904175519943237, 'eval_auc_roc_mean': 0.3063492178916931, 'eval_f1_mean': 0.2030014544725418, 'eval_acc_mean': 0.2825396955013275, 'eval_runtime': 0.5829, 'eval_samples_per_second': 17.156, 'eval_steps_per_second': 3.431, 'epoch': 3.0}


                                             


[A[A[A                                     

 80%|████████  | 4/5 [01:01<00:21, 21.09s/it]

[A[A
[A

{'loss': 0.0386, 'learning_rate': 1.6000000000000001e-06, 'epoch': 4.0}



                                             


  def wrap_fallback() -> None:
[A[A[A                                     

 80%|████████  | 4/5 [01:03<00:21, 21.09s/it]

[A[A
[A

{'eval_loss': 1.2769348621368408, 'eval_auc_roc_mean': 0.3063492178916931, 'eval_f1_mean': 0.2030014544725418, 'eval_acc_mean': 0.2825396955013275, 'eval_runtime': 0.3384, 'eval_samples_per_second': 29.552, 'eval_steps_per_second': 5.91, 'epoch': 4.0}


                                             


[A[A[A                                     

100%|██████████| 5/5 [01:03<00:00, 14.13s/it]

[A[A
[A

{'loss': 0.0393, 'learning_rate': 2.0000000000000003e-06, 'epoch': 5.0}



                                             


  def wrap_fallback() -> None:
[A[A[A                                     

100%|██████████| 5/5 [01:05<00:00, 14.13s/it]

[A[A
                                             


[A[A[A                                     

100%|██████████| 5/5 [01:05<00:00, 14.13s/it]

[A[A
[A

{'eval_loss': 1.2586853504180908, 'eval_auc_roc_mean': 0.3063492178916931, 'eval_f1_mean': 0.17229436337947845, 'eval_acc_mean': 0.21587303280830383, 'eval_runtime': 0.3887, 'eval_samples_per_second': 25.73, 'eval_steps_per_second': 5.146, 'epoch': 5.0}
{'train_runtime': 65.303, 'train_samples_per_second': 0.766, 'train_steps_per_second': 0.077, 'train_loss': 0.03927242383360863, 'epoch': 5.0}


c:\Users\hurbl\OneDrive\Área de Trabalho\Loon Factory\repository\Chest-X-Ray-Pathology-Classifier\notebooks\./output/25092023/google/vit-base-patch16-224/U-MultiClass is already a clone of https://huggingface.co/lucascruz/CheXpert-ViT-U-MultiClass. Make sure you pull the latest changes with `repo.git_pull()`.


0,1
eval/acc_mean,████▁
eval/auc_roc_mean,██▁▁▁
eval/f1_mean,████▁
eval/loss,█▇▆▄▁
eval/runtime,▇▁█▁▃
eval/samples_per_second,▂█▁▇▅
eval/steps_per_second,▂█▁▇▅
train/epoch,▁▁▃▃▅▅▆▆███
train/global_step,▁▁▃▃▅▅▆▆███
train/learning_rate,▁▃▅▆█

0,1
eval/acc_mean,0.21587
eval/auc_roc_mean,0.30635
eval/f1_mean,0.17229
eval/loss,1.25869
eval/runtime,0.3887
eval/samples_per_second,25.73
eval/steps_per_second,5.146
train/epoch,5.0
train/global_step,5.0
train/learning_rate,0.0


KeyboardInterrupt: 

In [51]:
data_path = r"C:/Users/hurbl/OneDrive/Área de Trabalho/Loon Factory/repository/Chest-X-Ray-Pathology-Classifier/data/raw/"

# Build the training split first, then the validation (dev) split, both with
# the U-MultiClass label encoding and 224x224 images.
train_dataset, val_dataset = (
    CheXpertDataset(
        data_path=data_path,
        uncertainty_policy="U-MultiClass",
        train=is_train,
        resize_shape=(224, 224))
    for is_train in (True, False)
)

In [31]:
from transformers import ViTForImageClassification

# ViT-Base/16 backbone with a newly initialised 5-output head; the
# checkpoint's 1000-way classifier is dropped because the shapes differ.
# NOTE(review): the U-MultiClass datasets built in the cell above produce
# 15-wide label vectors; confirm num_labels=5 is intended here.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=5,
    ignore_mismatched_sizes=True,
    problem_type="multi_label_classification",
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
import evaluate
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

METRICS = evaluate.combine(["accuracy", "f1", "precision", "recall"])

def compute_metrics(eval_pred):
    """Score predictions with the combined accuracy/F1/precision/recall.

    Single-label targets (1-D ``labels``) are compared with the argmax of
    the logits. Multi-label targets (2-D ``labels``, the layout
    ``CheXpertDataset`` produces) cannot be passed to these 1-D metrics
    directly. Instead, each logit is thresholded at 0 (sigmoid 0.5), and
    every (sample, output) pair is scored as one binary decision.

    Args:
        eval_pred: ``(predictions, labels)`` arrays supplied by Trainer;
            ``predictions`` are raw logits of shape (n_samples, n_outputs).

    Returns:
        dict: The result of ``METRICS.compute``.

    Raises:
        ValueError: If multi-label predictions and labels differ in shape.
    """
    predictions, labels = eval_pred
    labels = np.asarray(labels)
    if labels.ndim > 1:
        predictions = np.asarray(predictions)
        if predictions.shape != labels.shape:
            raise ValueError(
                f"prediction shape {predictions.shape} does not match "
                f"label shape {labels.shape}; check the model's num_labels")
        predictions = (predictions > 0).astype(int).ravel()
        labels = labels.astype(int).ravel()
    else:
        predictions = np.argmax(predictions, axis=1)
    return METRICS.compute(predictions=predictions, references=labels)


# Baseline fine-tuning setup: evaluate and checkpoint once per epoch, then
# reload the checkpoint with the best eval "accuracy" when training ends.
# Each optimizer step sees 16 images x 4 accumulation steps per device.
args = TrainingArguments(
    "vit_xray",
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    optim='adamw_torch',
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)



In [33]:
# Run fine-tuning and keep the returned TrainOutput for the optional
# persistence steps below.
train_results = trainer.train()

# Post-training persistence, currently disabled:
# trainer.save_model()
# trainer.log_metrics("train", train_results.metrics)
# trainer.save_metrics("train", train_results.metrics)
# trainer.save_state()


[A
  0%|          | 0/10473 [06:14<?, ?it/s]
  0%|          | 0/10473 [04:37<?, ?it/s]


In [None]:
# Score the trained model on the eval dataset, then log the results and
# write them to disk under the "eval" split name.
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [1]:
import torch
# Release PyTorch's cached CUDA memory, then display whether CUDA is available.
torch.cuda.empty_cache()
torch.cuda.is_available()

True

In [2]:
torch.zeros(1, device='cuda')  # smoke test: allocate directly on the current CUDA device

tensor([0.], device='cuda:0')

In [66]:
from torch import nn
import torch
from torch import masked_select

# Toy check of the U-Ignore masked loss: a batch of 10 samples x 5 labels,
# all positive, with four entries marked uncertain (-1).
labels = torch.ones([10, 5], dtype=torch.float32)
logits = torch.full([10, 5], 0.5)  # the same logit for every entry

# Mark (2,3), (3,3), (6,1) and (8,4) as uncertain in one indexed assignment.
labels[[2, 3, 6, 8], [3, 3, 1, 4]] = -1.

mask = labels > -1.  # True where the label is certain

masked_labels = masked_select(labels.view(-1, 5), mask)
masked_logits = masked_select(logits.view(-1, 5), mask)

# With every kept target equal to 1, the loss is log(1 + exp(-0.5)).
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(masked_logits, masked_labels)

In [67]:
loss  # display the masked BCE value computed on the toy batch

tensor(0.5541)

In [65]:
(logits, masked_logits)  # full 10x5 logits next to the 1-D masked selection

(tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]]),
 tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000]))

In [85]:
logits.flatten()  # all 50 logits as a single 1-D tensor

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000])