In [1]:
from torchvision.transforms import ColorJitter
from transformers import SegformerImageProcessor#,SegformerForSemanticSegmentation
import numpy as np
from PIL import Image
from  transformers.models.segformer.modeling_segformer import SegformerPreTrainedModel,SegformerModel,SegformerDecodeHead
from  transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# choose your loss https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        probs = torch.sigmoid(inputs)
        probs = probs.clamp(min=1e-4, max=1-1e-4)

        pt = (probs * targets) + ((1 - probs) * (1 - targets))
        focal_weight = (1 - pt).pow(self.gamma)

        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        bce_loss = F.binary_cross_entropy(probs, targets, reduction='none')

        loss = alpha_t * focal_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
from torch.nn import MSELoss

class CombinateLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean', size_average=None, reduce=None, reduction_mse='mean', coef=0.05):
        super(CombinateLoss, self).__init__()
        self.focal = FocalLoss(alpha, gamma, reduction)
        self.mse = MSELoss(size_average=size_average, reduce=reduce, reduction=reduction_mse)
        self.coef = coef

    def forward(self, inputs, targets):
        return self.focal(inputs, targets)*(1-self.coef) + self.mse(inputs, targets) * self.coef


In [3]:
class SegformerForСraft(SegformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.segformer = SegformerModel(config)
        self.decode_head = SegformerDecodeHead(config)

        self.init_weights()

    def forward(
        self,
        pixel_values,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.segformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.decode_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                # upsample logits to the images' original size
                # print('l',logits.shape)
                upsampled_logits = nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                # print('u',upsampled_logits.shape)

                loss_fct = CombinateLoss()# TODO add params 
                loss = loss_fct(upsampled_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=torch.clip(upsampled_logits,0,1), # interpolate loggits
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )

In [4]:
feature_extractor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b1-finetuned-ade-512-512")# choose differnt predtrain
model = SegformerForСraft.from_pretrained( 
    "nvidia/segformer-b1-finetuned-ade-512-512", # choose differnt predtrain
    num_labels=2,
    ignore_mismatched_sizes=True
)


  return func(*args, **kwargs)
Some weights of SegformerForСraft were not initialized from the model checkpoint at nvidia/segformer-b1-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([2, 256, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
type(feature_extractor)

transformers.models.segformer.image_processing_segformer.SegformerImageProcessor

In [6]:
# model.decode_head

In [7]:
np.load('test_data/test_heatmap.npy').max()

1.2076961

In [8]:
!ls test_data

test_heatmap.npy  test_img.png


In [9]:
import torch
from torch.utils.data import Dataset


class ExampleCRAFTDataset(Dataset):
    def __init__(self,  feature_extractor):
       
        self.feature_extractor = feature_extractor

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        image = Image.open('test_data/test_img.png').convert("RGB")
        
      
        mask = np.load('test_data/test_heatmap.npy')

        

        encoding = self.feature_extractor(
            image,
            size=512,
            do_resize=True,
            do_normalize=True,
            return_tensors="pt"
        )

        pixel_values = encoding['pixel_values'].squeeze(0)  # (3, 512, 512)


        mask_pil_text = Image.fromarray(mask[:,:,0])  # 0-й канал (Text)
        mask_pil_link = Image.fromarray(mask[:,:,1])  # 1-й канал (Link)
        
        mask_pil_text = mask_pil_text.resize((512, 512), resample=Image.BICUBIC)
        mask_pil_link = mask_pil_link.resize((512, 512), resample=Image.BICUBIC)

        mask_resized_text = np.array(mask_pil_text, dtype=np.float32)
        mask_resized_link = np.array(mask_pil_link, dtype=np.float32)

        mask_2ch = np.stack([mask_resized_text, mask_resized_link], axis=0)

        labels = torch.from_numpy(np.clip(mask_2ch, 0, 1))  # shape (2, 512, 512)

        return {
            "pixel_values": pixel_values,  # (3, 512, 512)
            "labels": labels               # (2, 512, 512)
        }


In [10]:
test_dataset = ExampleCRAFTDataset(feature_extractor)

In [11]:
test_dataset[0]['labels'].shape

torch.Size([2, 512, 512])

In [12]:
test_dataset[0]['labels'].type()

'torch.FloatTensor'

In [13]:
def craft_data_collator(batch):
    pixel_values = torch.stack([item["pixel_values"] for item in batch])  
    labels = torch.stack([item["labels"] for item in batch])  
    # pixel_values: (B, 3, 512, 512)
    # labels:       (B, 2, 512, 512)

    return {
        "pixel_values": pixel_values,
        "labels": labels
    }


In [14]:
model.config.id2label = {0: "text", 1: "link"}
model.config.label2id = {"text": 0, "link": 1}

In [15]:
test_batch = {
        "pixel_values":  torch.stack([test_dataset[0]["pixel_values"]]),
        "labels":torch.stack([test_dataset[0]["labels"]]),
    }

In [16]:
with torch.no_grad():
    outputs = model(**test_batch)


In [17]:
outputs['loss']

tensor(0.0142)

In [18]:
outputs['logits'].shape

torch.Size([1, 2, 512, 512])

In [19]:
from transformers import TrainingArguments, Trainer
import craft_utils
training_args = TrainingArguments(
    output_dir="./segformer_craft",
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    evaluation_strategy="epoch",  # Будем оцениваться каждый эпох
    save_strategy="epoch",        # Сохранять чекпоинты раз в эпоху
    logging_steps=50,             # Каждые 50 итераций в лог
    learning_rate=5e-5,
    weight_decay=0.01,
    # и т.д.
)

    

    # coordinate adjustment
def compute_metrics(eval_preds):
    logits, labels =eval_preds
    # add metrics 
    # print(logits.shape)#: (batch_size, 2, H, W)
    # print(labels.shape)#: (batch_size, 2, H, W)  (если мы так храним)
    return {}




In [20]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=test_dataset,
    eval_dataset=test_dataset,
    data_collator=craft_data_collator,
    compute_metrics=compute_metrics,  # опционально
)


In [21]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,0.014263
2,No log,0.013152
3,No log,0.012963
4,No log,0.012984
5,No log,0.012849


TrainOutput(global_step=25, training_loss=0.014271169900894165, metrics={'train_runtime': 5.0048, 'train_samples_per_second': 9.99, 'train_steps_per_second': 4.995, 'total_flos': 3226988917555200.0, 'train_loss': 0.014271169900894165, 'epoch': 5.0})