
Export to ONNX supported? #19

Closed
dnth opened this issue Apr 19, 2023 · 25 comments
Labels
enhancement New feature or request

Comments


dnth commented Apr 19, 2023

I'm trying to export the model to .onnx.

Here's my code:

import torch
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to('cuda:0')
model.eval()

# Generate some input data
input_data = torch.randn(1, 3, 224, 224).to('cuda:0')

# Pass the input data through the model
output = model(input_data)

torch.onnx.export(model, input_data, 'model.onnx')

I got an error:

============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
[<ipython-input-17-2f8453f4374c>](https://localhost:8080/#) in <cell line: 1>()
----> 1 torch.onnx.export(model, input_data, 'model.onnx')

13 frames
[~/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py](https://localhost:8080/#) in prepare_tokens_with_masks(self, x, masks)
    193         x = self.patch_embed(x)
    194         if masks is not None:
--> 195             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
    196 
    197         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Did I miss anything?


oylz commented Apr 19, 2023

I exported it successfully on CPU.

1. First step: don't use MemEffAttention

Change this line

  • from
block_fn=partial(Block, attn_class=MemEffAttention),
  • to
block_fn=partial(Block, attn_class=Attention),
  • remember to import it
from dinov2.layers.attention import Attention

2. Second step

  • change the device from cuda:0 to cpu in your code
import torch
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to('cpu')
model.eval()

# Generate some input data
input_data = torch.randn(1, 3, 224, 224).to('cpu')

# Pass the input data through the model
output = model(input_data)

torch.onnx.export(model, input_data, 'model.onnx')


dnth commented Apr 19, 2023

Thank you @oylz. I did that and now end up with another error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 11
      8 # Pass the input data through the model
      9 output = model(input_data)
---> 11 torch.onnx.export(model, input_data, 'model.onnx')

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/onnx/utils.py:506, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
    188 @_beartype.beartype
    189 def export(
    190     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    206     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
    207 ) -> None:
    208     r"""Exports a model into ONNX format.
    209 
    210     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    503             All errors are subclasses of :class:`errors.OnnxExporterError`.
    504     """
--> 506     _export(
    507         model,
    508         args,
    509         f,
    510         export_params,
    511         verbose,
    512         training,
    513         input_names,
    514         output_names,
    515         operator_export_type=operator_export_type,
    516         opset_version=opset_version,
    517         do_constant_folding=do_constant_folding,
    518         dynamic_axes=dynamic_axes,
    519         keep_initializers_as_inputs=keep_initializers_as_inputs,
    520         custom_opsets=custom_opsets,
    521         export_modules_as_functions=export_modules_as_functions,
    522     )

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/onnx/utils.py:1548, in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions)
   1545     dynamic_axes = {}
   1546 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
-> 1548 graph, params_dict, torch_out = _model_to_graph(
   1549     model,
   1550     args,
   1551     verbose,
   1552     input_names,
   1553     output_names,
   1554     operator_export_type,
   1555     val_do_constant_folding,
   1556     fixed_batch_size=fixed_batch_size,
   1557     training=training,
   1558     dynamic_axes=dynamic_axes,
   1559 )
   1561 # TODO: Don't allocate a in-memory string for the protobuf
   1562 defer_weight_export = (
   1563     export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
   1564 )

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/onnx/utils.py:1113, in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
   1110     args = (args,)
   1112 model = _pre_trace_quant_model(model, args)
-> 1113 graph, params, torch_out, module = _create_jit_graph(model, args)
   1114 params_dict = _get_named_param_dict(graph, params)
   1116 try:

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/onnx/utils.py:989, in _create_jit_graph(model, args)
    984     graph = _C._propagate_and_assign_input_shapes(
    985         graph, flattened_args, param_count_list, False, False
    986     )
    987     return graph, params, torch_out, None
--> 989 graph, torch_out = _trace_and_get_graph_from_model(model, args)
    990 _C._jit_pass_onnx_lint(graph)
    991 state_dict = torch.jit._unique_state_dict(model)

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/onnx/utils.py:893, in _trace_and_get_graph_from_model(model, args)
    891 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    892 torch.set_autocast_cache_enabled(False)
--> 893 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
    894     model,
    895     args,
    896     strict=False,
    897     _force_outplace=False,
    898     _return_inputs_states=True,
    899 )
    900 torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
    902 warn_on_static_input_change(inputs_states)

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/jit/_trace.py:1268, in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1266 if not isinstance(args, tuple):
   1267     args = (args,)
-> 1268 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
   1269 return outs

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/jit/_trace.py:127, in ONNXTracedModule.forward(self, *args)
    124     else:
    125         return tuple(out_vars)
--> 127 graph, out = torch._C._create_graph_by_tracing(
    128     wrapper,
    129     in_vars + module_state,
    130     _create_interpreter_name_lookup_fn(),
    131     self.strict,
    132     self._force_outplace,
    133 )
    135 if self._return_inputs:
    136     return graph, outs[0], ret_inputs[0]

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/jit/_trace.py:118, in ONNXTracedModule.forward.<locals>.wrapper(*args)
    116 if self._return_inputs_states:
    117     inputs_states.append(_unflatten(in_args, in_desc))
--> 118 outs.append(self.inner(*trace_inputs))
    119 if self._return_inputs_states:
    120     inputs_states[0] = (inputs_states[0], trace_inputs)

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/nn/modules/module.py:1488, in Module._slow_forward(self, *input, **kwargs)
   1486         recording_scopes = False
   1487 try:
-> 1488     result = self.forward(*input, **kwargs)
   1489 finally:
   1490     if recording_scopes:

File ~/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py:292, in DinoVisionTransformer.forward(self, is_training, *args, **kwargs)
    291 def forward(self, *args, is_training=False, **kwargs):
--> 292     ret = self.forward_features(*args, **kwargs)
    293     if is_training:
    294         return ret

File ~/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py:226, in DinoVisionTransformer.forward_features(self, x, masks)
    223 if isinstance(x, list):
    224     return self.forward_features_list(x, masks)
--> 226 x = self.prepare_tokens_with_masks(x, masks)
    228 for blk in self.blocks:
    229     x = blk(x)

File ~/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py:199, in DinoVisionTransformer.prepare_tokens_with_masks(self, x, masks)
    196     x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
    198 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
--> 199 x = x + self.interpolate_pos_encoding(x, w, h)
    201 return x

File ~/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py:182, in DinoVisionTransformer.interpolate_pos_encoding(self, x, w, h)
    178 # we add a small number to avoid floating point error in the interpolation
    179 # see discussion at https://github.com/facebookresearch/dino/issues/8
    180 w0, h0 = w0 + 0.1, h0 + 0.1
--> 182 patch_pos_embed = nn.functional.interpolate(
    183     patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
    184     scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
    185     mode="bicubic",
    186 )
    188 assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
    189 patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

File ~/anaconda3/envs/dinov2/lib/python3.9/site-packages/torch/nn/functional.py:3967, in interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
   3965     if antialias:
   3966         return torch._C._nn._upsample_bicubic2d_aa(input, output_size, align_corners, scale_factors)
-> 3967     return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors)
   3969 if input.dim() == 3 and mode == "bilinear":
   3970     raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")

TypeError: upsample_bicubic2d() received an invalid combination of arguments - got (Tensor, NoneType, bool, tuple), but expected one of:
 * (Tensor input, tuple of ints output_size, bool align_corners, tuple of floats scale_factors)
      didn't match because some of the arguments have invalid types: (Tensor, !NoneType!, bool, !tuple of (Tensor, Tensor)!)
 * (Tensor input, tuple of ints output_size, bool align_corners, float scales_h, float scales_w, *, Tensor out)

onnx version - 1.13.1


oylz commented Apr 19, 2023

@dnth
you can continue like this:
replace this block

  • from
patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 
    scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
    mode="bicubic",
)
  • to
aa = patch_pos_embed.reshape(
        1,
        int(math.sqrt(N)),
        int(math.sqrt(N)),
        dim
    ).permute(0, 3, 1, 2)
bb = (w0 / math.sqrt(N), h0 / math.sqrt(N))
cc = bb
if isinstance(bb[0], torch.Tensor):
    # 1. convert cc from a tuple of scalar tensors to a tuple of plain floats
    cc = (bb[0].item(), bb[1].item())
patch_pos_embed = nn.functional.interpolate(
    aa,
    scale_factor=cc,
    mode="bilinear",  # 2. was "bicubic"; if left unchanged, the exported ONNX model raises a runtime exception
)


dnth commented Apr 19, 2023

@oylz thank you again! The export to ONNX no longer produces any errors!

@dbickson

Hi @oylz, once we build an ONNX model, we see it has two inputs. The second one is named masks. What should be given there as input? Is there a way to extract feature vectors without the masks?


oylz commented Apr 19, 2023

@dbickson
you can write a wrapper class and then export it:

import torch
from PIL import Image
import torchvision.transforms as T
import hubconf

class xyz_model(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model 

    def forward(self, tensor):
        ff = self.model(tensor)
        return ff

model = hubconf.dinov2_vitl14()
mm = xyz_model(model).to('cpu')
mm.eval()
....do export
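
For reference, the elided export step could mirror the pattern used elsewhere in this thread (a minimal sketch; the 1x3x224x224 dummy input is just an illustration). Because the wrapper's forward takes a single tensor, the exported graph ends up with only the image input and no masks input:

# dummy input and export call (same pattern as the full script posted later in the thread)
input_data = torch.randn(1, 3, 224, 224).to('cpu')
output = mm(input_data)
torch.onnx.export(mm, input_data, 'model.onnx', input_names=['input'])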

patricklabatut added the enhancement (New feature or request) label Apr 19, 2023
@patricklabatut
Contributor

Thanks @oylz for helping out!

@dnth We have never tried to export the models to ONNX. If there is enough feedback / signal, we could consider providing model weights in this format as well.


dbickson commented Apr 20, 2023

Hi @patricklabatut, we were able to successfully export the model to ONNX and give it a try (@dnth, @amiralush and I are working together). Results look really good compared to other network embeddings. The final trick was to reset the masks, otherwise ONNX expects two inputs.

Here are example images using fastdup for clustering on the Oxford-IIIT Pet Dataset.


The results are really good since the model understands the animal breed characteristics.
It would be super useful if you could export the model to ONNX and share it on the release page of this repo; it could encourage wide adoption of the model.

Many thanks for your great project!!


oylz commented Apr 20, 2023

@dbickson show me your export code, maybe I can help you.


dnth commented Apr 20, 2023

@oylz here's the code we used:

import torch
from PIL import Image
import torchvision.transforms as T
import hubconf

class xyz_model(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, tensor):
        ff = self.model(tensor)
        return ff

### Change your model here .dinov2_vits14() / .dinov2_vitb14() / .dinov2_vitl14() / .dinov2_vitg14()
model = hubconf.dinov2_vits14()
mm = xyz_model(model).to('cpu')

mm.eval()
input_data = torch.randn(1, 3, 224, 224).to('cpu')
output = mm(input_data)

torch.onnx.export(mm, input_data, 'model.onnx', input_names=['input'])


oylz commented Apr 20, 2023

@dnth sorry, but I see the ONNX model with only one input. The following are my steps:

1. Export to ONNX

  • 420.py
import torch 
from PIL import Image 
import torchvision.transforms as T 
import hubconf 
  
class xyz_model(torch.nn.Module): 
    def __init__(self, model): 
        super().__init__() 
        self.model = model  
 
    def forward(self, tensor): 
        ff = self.model(tensor) 
        return ff  
    
### Change your model here .dinov2_vits14() / .dinov2_vitb14() / .dinov2_vitl14() /.dinov2_vitg14()  
model = hubconf.dinov2_vits14() 
mm = xyz_model(model).to('cpu') 
 
mm.eval() 
input_data = torch.randn(1, 3, 224, 224).to('cpu') 
output = mm(input_data) 
 
torch.onnx.export(mm, input_data, 'model.onnx', input_names = ['input'])
print("ok")
  • output
/data/dinov2/dinov2/models/vision_transformer.py:170: TracerWarning: Converting a tensor to a Python 
  if npatch == N and w == h:
/data/dinov2/dinov2/models/vision_transformer.py:184: TracerWarning: Converting a tensor to a Python
  int(math.sqrt(N)),
/data/dinov2/dinov2/models/vision_transformer.py:185: TracerWarning: Converting a tensor to a Python
  int(math.sqrt(N)),
/data/dinov2/dinov2/models/vision_transformer.py:194: TracerWarning: Converting a tensor to a Python
  bb = (w0 / math.sqrt(N), h0 / math.sqrt(N))
/data/dinov2/dinov2/models/vision_transformer.py:197: TracerWarning: Converting a tensor to a Python
  cc = (bb[0].item(), bb[1].item())
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

ok

2. Show the ONNX model inputs and outputs

  • test_onnx.py
import onnxruntime

def print_input_output(sess):
    iis = sess.get_inputs()
    for ii in iis:
        print(ii.name, ii.shape)
    print("-------------------------------")
    oos = sess.get_outputs()
    for oo in oos:
        print(oo.name, oo.shape)

model = onnxruntime.InferenceSession("./model.onnx")

print_input_output(model)
  • output
input [1, 3, 224, 224]
-------------------------------
1232 [1, 384]
  • as we can see, there is only one input, named input


nietras commented Apr 20, 2023

I'd be very interested in these models being readily available in ONNX format too 🙏


oylz commented Apr 20, 2023

@dnth Have you solved it? Sorry, I have OCD. If this question is not closed, I will always think about it.


dnth commented Apr 20, 2023

Hello @oylz, sorry to keep you waiting. Your solution works and we were able to export the ONNX model. Thank you so much for the help!
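
For anyone following along, a quick numerical sanity check of the exported file could look like this (a minimal sketch; it assumes onnxruntime is installed, that model.onnx was produced with the wrapper script above, and that mm and input_data are still in scope):

import numpy as np
import onnxruntime

# reference output from the PyTorch wrapper used for the export
with torch.no_grad():
    torch_out = mm(input_data).numpy()

# output from the exported ONNX graph (the input was named 'input' at export time)
sess = onnxruntime.InferenceSession("model.onnx")
onnx_out = sess.run(None, {"input": input_data.numpy()})[0]

# the two embeddings should agree closely; small differences can come from the exporter
print(np.abs(torch_out - onnx_out).max())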

dnth closed this as completed Apr 20, 2023
@dbickson

We created a short tutorial notebook to help people evaluate the great quality of the DINOv2 model by computing embeddings via DINOv2, clustering them, and viewing the results.
Here is our example: https://colab.research.google.com/gist/dbickson/dcc77eaeeee05252f533d794b7f53e4e/dinov2_notebook.ipynb


woctezuma commented Apr 21, 2023

We created a short tutorial notebook to help people evaluate the great quality of the DINOv2 model by computing embeddings via DINOv2, clustering them, and viewing the results. Here is our example: https://colab.research.google.com/gist/dbickson/dcc77eaeeee05252f533d794b7f53e4e/dinov2_notebook.ipynb

This looks really cool! I see the notebook is also hosted at visual-layer/fastdup (Dino v2 Embeddings).

Edit: albeit a bit slow, it seems, maybe because it is not using the GPU.
According to Google Colab, this one-time initial procedure took 30 minutes, which is more than advertised for .run() with the default arguments. 😅 I don't know if it is normal. If so, the ability to use a GPU would be nice. Maybe it is already possible?


%pip install -qq fastdup
!wget https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz
!tar xf images.tar.gz
import fastdup

fd = fastdup.create(input_dir="images/", work_dir="fastdup_work_dir/")
fd.run(model_path='dinov2s', cc_threshold=0.9, overwrite=True)
Dataset Analysis Summary: 

    Dataset contains 7390 images
    Valid images are 99.92% (7,384) of the data, invalid are 0.08% (6) of the data
    For a detailed analysis, use `.invalid_instances()`.

    Similarity:  2.84% (210) belong to 9 similarity clusters (components).
    97.16% (7,180) images do not belong to any similarity cluster.
    Largest cluster has 34 (0.46%) images.
    For a detailed analysis, use `.connected_components()`
(similarity threshold used is 0.9, connected component threshold used is 0.9).

    Outliers: 6.28% (464) of images are possible outliers, and fall in the bottom 5.00% of similarity values.
    For a detailed list of outliers, use `.outliers()`.
filenames, feature_vec = fastdup.load_binary_feature("fastdup_work_dir/atrain_features.dat", d=384)  
print("Feature vector matrix dimensions", feature_vec.shape)
Read a total of  7384 images
Feature vector matrix dimensions (7384, 384)

Visualization is much faster. 😎

fd.vis.component_gallery(keep_aspect_ratio=True)
100%|██████████| 20/20 [00:00<00:00, 22.55it/s]

Finished OK. Components are stored as image files fastdup_work_dir/galleries/components_[index].jpg
Stored components visual view in  fastdup_work_dir/galleries/components.html
Execution time in seconds 2.3

Not sure if inference is supported to find the nearest neighbors of an input image after generating the .dat file of features.
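
For what it's worth, once the feature matrix is loaded as above, a brute-force nearest-neighbor lookup is easy to do by hand (a minimal sketch using NumPy cosine similarity; feature_vec and filenames come from fastdup.load_binary_feature above, and the query index is an arbitrary illustration):

import numpy as np

# L2-normalize the embeddings so a dot product equals cosine similarity
feats = feature_vec / np.linalg.norm(feature_vec, axis=1, keepdims=True)

query_idx = 0                    # index of the query image (arbitrary choice)
sims = feats @ feats[query_idx]  # cosine similarity of the query to every image

# indices of the 5 most similar images, excluding the query itself
nearest = np.argsort(-sims)[1:6]
for i in nearest:
    print(filenames[i], float(sims[i]))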

@ChocoL0rd

How can I do the same matching across images as in the article, figure 10 on page 18?
As I understand it, matching patch-level features means taking hidden-layer features from the first and second images and then computing some distance between them (like cosine) with a threshold.
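
A rough sketch of that idea (hedged: this is not an official recipe; it assumes forward_features returns a dict with an "x_norm_patchtokens" entry as in the DINOv2 repo, and the input shapes and threshold are illustrative only):

import torch
import torch.nn.functional as F

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model.eval()

img1 = torch.randn(1, 3, 224, 224)   # replace with real preprocessed images
img2 = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # per-patch embeddings, shape (num_patches, dim)
    f1 = model.forward_features(img1)["x_norm_patchtokens"][0]
    f2 = model.forward_features(img2)["x_norm_patchtokens"][0]

# cosine similarity between every patch of img1 and every patch of img2
f1 = F.normalize(f1, dim=-1)
f2 = F.normalize(f2, dim=-1)
sim = f1 @ f2.T                       # (num_patches, num_patches)

# for each patch of img1, keep its best match in img2 if it clears a threshold
best_sim, best_idx = sim.max(dim=1)
threshold = 0.6                       # illustrative value
matches = [(i, int(best_idx[i])) for i in range(len(best_sim)) if best_sim[i] > threshold]
print(len(matches), "patch matches above the threshold")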

@dukeeagle

Seconding @nietras ! Would love to see if anyone's already exported ONNX models. Happy to provide Google Drive storage space so everyone can get access

@PeterKim1

@dbickson you can write a wrapper class, and then export it

import torch
from PIL import Image
import torchvision.transforms as T
import hubconf

class xyz_model(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model 

    def forward(self, tensor):
        ff = self.model(tensor)
        return ff

model = hubconf.dinov2_vitl14()
mm = xyz_model(model).to('cpu')
mm.eval()
....do export

Hi. I want to use your code, but I can't import hubconf.

How to import hubconf?
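
For what it's worth, hubconf.py lives at the root of the facebookresearch/dinov2 repository (that is what makes torch.hub.load work), so one way to make import hubconf work is to clone the repo and run from, or point sys.path at, that directory. A minimal sketch (the local path is illustrative):

# git clone https://github.com/facebookresearch/dinov2.git
import sys
sys.path.insert(0, "dinov2")  # path to the cloned repo root (illustrative)

import hubconf
model = hubconf.dinov2_vits14()

# or, without a local clone, as used earlier in this thread:
import torch
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')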


barbolo commented Nov 29, 2023

@sefaburakokcu

Seconding @nietras ! Would love to see if anyone's already exported ONNX models. Happy to provide Google Drive storage space so everyone can get access

Hi @dukeeagle!
I have exported the models to ONNX. You can also perform exports using https://github.com/sefaburakokcu/dinov2. Additionally, the exported models can be downloaded from:

Hugging Face

Google Drive

You can run the ONNX models by utilizing https://github.com/sefaburakokcu/dinov2_onnx/.


barbolo commented Jan 7, 2024

@sefaburakokcu have you exported ONNX outputs with both the class token + patch tokens?

@sefaburakokcu

Hi @barbolo,

I have exported only feature extraction models.

@sammilei

@sefaburakokcu have you exported ONNX outputs with both the class token + patch tokens?

Would be nice to have some guidance for this.
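
One possible approach, sketched below (hedged: this is not necessarily what @barbolo did in #167; it assumes forward_features returns a dict containing "x_norm_clstoken" and "x_norm_patchtokens" as in the DINOv2 repo, and that the export tweaks discussed earlier in this thread are applied where needed):

import torch

class TokensWrapper(torch.nn.Module):
    """Expose the class token and the patch tokens as two named ONNX outputs."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, tensor):
        feats = self.model.forward_features(tensor)
        return feats["x_norm_clstoken"], feats["x_norm_patchtokens"]

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to('cpu')
wrapper = TokensWrapper(model).eval()

input_data = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    wrapper,
    input_data,
    'model_tokens.onnx',
    input_names=['input'],
    output_names=['class_token', 'patch_tokens'],
)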


barbolo commented Apr 1, 2024

I've exported class token + patch tokens: #167 (comment)
