
Multimodal deep learning example throws error: Can't get attribute 'DocVec[TextDoc]' on <module 'docarray.array.any_array' #1614

Closed
Robbie-Palmer opened this issue Jun 2, 2023 · 4 comments · Fixed by #1615

@Robbie-Palmer

The documentation contains a "How-to" guide on training a multimodal CLIP-esque model.
There are a number of small bugs, such as incorrectly named classes or parameters, which are easily worked through, resulting in a script roughly like this:

import itertools
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
import torchvision
from docarray import BaseDoc, DocList, DocVec
from docarray.data import MultiModalDataset
from docarray.documents import TextDoc as BaseText
from docarray.typing import TorchTensor, ImageUrl
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, DistilBertModel

DEVICE = "cuda:0"


class Tokens(BaseDoc):
    input_ids: TorchTensor[48]
    attention_mask: TorchTensor


class TextDoc(BaseText):
    tokens: Optional[Tokens]


class ImageDoc(BaseDoc):
    url: Optional[ImageUrl]
    tensor: Optional[TorchTensor]
    embedding: Optional[TorchTensor]


class PairTextImage(BaseDoc):
    text: TextDoc
    image: ImageDoc


class VisionPreprocess:
    def __init__(self):
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize(232),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def __call__(self, image: ImageDoc) -> None:
        image.tensor = self.transform(image.url.load())


class TextPreprocess:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    def __call__(self, text: TextDoc) -> None:
        assert isinstance(text, TextDoc)
        text.tokens = Tokens(**self.tokenizer(text.text, padding="max_length", truncation=True, max_length=48))


class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

    def forward(self, texts: DocList[TextDoc]) -> TorchTensor:
        last_hidden_state = self.bert(
            input_ids=texts.tokens.input_ids, attention_mask=texts.tokens.attention_mask
        ).last_hidden_state

        return self._mean_pool(last_hidden_state, texts.tokens.attention_mask)

    @staticmethod
    def _mean_pool(last_hidden_state: TorchTensor, attention_mask: TorchTensor) -> TorchTensor:
        masked_output = last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet18(pretrained=True)
        self.linear = nn.LazyLinear(out_features=768)

    def forward(self, images: DocList[ImageDoc]) -> TorchTensor:
        x = self.backbone(images.tensor)
        return self.linear(x)


def cosine_sim(x_mat: TorchTensor, y_mat: TorchTensor) -> TorchTensor:
    a_n, b_n = x_mat.norm(dim=1)[:, None], y_mat.norm(dim=1)[:, None]
    a_norm = x_mat / torch.clamp(a_n, min=1e-7)
    b_norm = y_mat / torch.clamp(b_n, min=1e-7)
    return torch.mm(a_norm, b_norm.transpose(0, 1)).squeeze()


def clip_loss(image: DocList[ImageDoc], text: DocList[TextDoc]) -> TorchTensor:
    sims = cosine_sim(image.embedding, text.embedding)
    return torch.norm(sims - torch.eye(sims.shape[0], device=DEVICE))


def load_flickr8k_doclist(dataset_dir: Path = Path('.'), n: Optional[int] = None) -> DocList[PairTextImage]:
    df = pd.read_csv(dataset_dir / 'captions.txt', nrows=n)
    doc_list = DocList[PairTextImage](
        PairTextImage(text=TextDoc(text=sample.caption), image=ImageDoc(url=f"{dataset_dir}/Images/{sample.image}"))
        for sample in df.itertuples())
    return doc_list


flickr8k_doclist = load_flickr8k_doclist(Path.home() / 'datasets/flickr8k')
dataset = MultiModalDataset[PairTextImage](docs=flickr8k_doclist,
                                           preprocessing=dict(image=VisionPreprocess(), text=TextPreprocess()))
loader = DataLoader(dataset,
                    batch_size=128,
                    collate_fn=dataset.collate_fn,
                    shuffle=True,
                    num_workers=1)  # batches are pickled across the worker/main process boundary

vision_encoder = VisionEncoder().to(DEVICE)
text_encoder = TextEncoder().to(DEVICE)
optim = torch.optim.Adam(itertools.chain(vision_encoder.parameters(), text_encoder.parameters()), lr=3e-4)

num_epoch = 1

with torch.autocast(device_type="cuda", dtype=torch.float16):
    for epoch in range(num_epoch):
        for i, batch in tqdm(enumerate(loader), total=len(loader), desc=f"Epoch {epoch}"):
            batch.to(DEVICE)

            optim.zero_grad()
            batch.image.embedding = vision_encoder(batch.image)
            batch.text.embedding = text_encoder(batch.text)
            loss = clip_loss(batch.image, batch.text)
            if i % 30 == 0:
                print(f"{i + epoch} steps , loss : {loss}")
            loss.backward()
            optim.step()

But PyTorch throws an error that seems to originate deeper in DocArray when trying to load the data:

Traceback (most recent call last):
  File ".../multimodal_embedding.py", line 127, in <module>
    for i, batch in tqdm(enumerate(loader), total=len(loader), desc=f"Epoch {epoch}"):
  File ".../python3.10/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File ".../python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File ".../python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File ".../python3.10/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
    success, data = self._try_get_data()
  File ".../python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File ".../python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'DocVec[TextDoc]' on <module 'docarray.array.any_array' from '.../python3.10/site-packages/docarray/array/any_array.py'>

This is a similar error to #1480, but it can't be worked around by initializing DocVec[PairTextImage], as the failure happens inside a PyTorch process trying to unpickle from a queue.
This may or may not be handled by #1330 when it is resolved.
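
For context, a minimal sketch of the failure mode (illustrative only, not DocArray code; the Dynamic[TextDoc] name is made up): pickle records a class by module and qualified name and re-resolves that name on load, so a class created dynamically, the way DocVec[TextDoc] is, can only be unpickled in a process that has already created and registered it:

import pickle
import sys

mod = sys.modules[__name__]

# Stand-in for what a __class_getitem__ like DocVec[TextDoc] does:
# build a class at runtime and register it on a module under its
# parametrized name
Dynamic = type("Dynamic[TextDoc]", (), {})
setattr(mod, "Dynamic[TextDoc]", Dynamic)

data = pickle.dumps(Dynamic())  # succeeds: the name resolves in this process

# Simulate the other side of the DataLoader queue: a process in which
# the parametrized class was never created
delattr(mod, "Dynamic[TextDoc]")
try:
    pickle.loads(data)
except AttributeError as e:
    print(e)  # Can't get attribute 'Dynamic[TextDoc]' on <module '__main__' ...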

@JoanFM (Member) commented Jun 2, 2023

Hey @Robbie-Palmer, we are going to look into it ASAP

@JoanFM (Member) commented Jun 2, 2023

Does it work if you add this after the classes are defined, before the rest of the script runs?

DocVec[TextDoc]

If so, I will change the documentation and mention that it should not be needed once #1330 is done

@Robbie-Palmer (Author)

It's running now that I've added these four:

DocVec[Tokens]
DocVec[TextDoc]
DocVec[ImageDoc]
DocVec[PairTextImage]

So the script now looks like:

from typing import Optional

from docarray import BaseDoc, DocVec
from docarray.documents import TextDoc as BaseText
from docarray.typing import TorchTensor, ImageUrl

DEVICE = "cuda:0"


class Tokens(BaseDoc):
    input_ids: TorchTensor[48]
    attention_mask: TorchTensor


class TextDoc(BaseText):
    tokens: Optional[Tokens]


class ImageDoc(BaseDoc):
    url: Optional[ImageUrl]
    tensor: Optional[TorchTensor]
    embedding: Optional[TorchTensor]


class PairTextImage(BaseDoc):
    text: TextDoc
    image: ImageDoc


# Subscripting DocVec here creates and registers each parametrized class
# up front, so every process (including DataLoader workers) can resolve
# names like 'DocVec[TextDoc]' when unpickling
DocVec[Tokens]
DocVec[TextDoc]
DocVec[ImageDoc]
DocVec[PairTextImage]

import itertools
from pathlib import Path

import pandas as pd
import torch
import torchvision
from docarray import DocList
from docarray.data import MultiModalDataset
from docarray.typing import TorchTensor
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, DistilBertModel
...

@JoanFM (Member) commented Jun 2, 2023

A PR to add this to the documentation has been created in #1615; the underlying issue will be resolved in #1330
