Skip to content

Commit

Permalink
style: use isort
Browse files Browse the repository at this point in the history
  • Loading branch information
borisdayma committed Nov 30, 2021
1 parent fb1fbca commit d209547
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 58 deletions.
14 changes: 0 additions & 14 deletions .github/workflows/black.yml

This file was deleted.

20 changes: 6 additions & 14 deletions app/gradio/app_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,18 @@

import random

import gradio as gr
import jax
import flax.linen as nn
from flax.training.common_utils import shard
import numpy as np
from flax.jax_utils import replicate

from transformers import BartTokenizer

from flax.training.common_utils import shard
from PIL import Image, ImageDraw, ImageFont
import numpy as np

from vqgan_jax.modeling_flax_vqgan import VQModel
from dalle_mini.model import CustomFlaxBartForConditionalGeneration

# ## CLIP Scoring
from transformers import CLIPProcessor, FlaxCLIPModel

import gradio as gr

from PIL import Image, ImageDraw, ImageFont
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
from vqgan_jax.modeling_flax_vqgan import VQModel

from dalle_mini.model import CustomFlaxBartForConditionalGeneration

DALLE_REPO = "flax-community/dalle-mini"
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
Expand Down
3 changes: 2 additions & 1 deletion app/streamlit/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#!/usr/bin/env python
# coding: utf-8

from .backend import ServiceError, get_images_from_backend
import streamlit as st

from .backend import ServiceError, get_images_from_backend

st.sidebar.markdown(
"""
<style>
Expand Down
5 changes: 3 additions & 2 deletions app/streamlit/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import requests
from io import BytesIO
import base64
from io import BytesIO

import requests
from PIL import Image


Expand Down
6 changes: 4 additions & 2 deletions dalle_mini/data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import partial
import numpy as np

import jax
import jax.numpy as jnp
import numpy as np
from datasets import Dataset, load_dataset
from flax.training.common_utils import shard

from .text import TextNormalizer


Expand Down
14 changes: 6 additions & 8 deletions dalle_mini/model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import jax
import flax.linen as nn

import jax
from transformers import BartConfig
from transformers.models.bart.modeling_flax_bart import (
FlaxBartModule,
FlaxBartForConditionalGenerationModule,
FlaxBartForConditionalGeneration,
FlaxBartEncoder,
FlaxBartDecoder,
FlaxBartEncoder,
FlaxBartForConditionalGeneration,
FlaxBartForConditionalGenerationModule,
FlaxBartModule,
)

from transformers import BartConfig


class CustomFlaxBartModule(FlaxBartModule):
def setup(self):
Expand Down
8 changes: 5 additions & 3 deletions dalle_mini/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
Utilities for processing text.
"""

import html
import math
import random
import re
from pathlib import Path
from unidecode import unidecode

import re, math, random, html
import ftfy

from huggingface_hub import hf_hub_download
from unidecode import unidecode

# based on wiki word occurence
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
Expand Down
22 changes: 8 additions & 14 deletions tools/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,31 @@
Script adapted from run_summarization_flax.py
"""

import os
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Callable, Optional
import json

import datasets
from datasets import Dataset
from tqdm import tqdm
from dataclasses import asdict

import jax
import jax.numpy as jnp
import optax
import transformers
import wandb
from datasets import Dataset
from flax import jax_utils, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.jax_utils import unreplicate
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
from transformers import (
AutoTokenizer,
HfArgumentParser,
)
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser
from transformers.models.bart.modeling_flax_bart import BartConfig

import wandb

from dalle_mini.data import Dataset
from dalle_mini.model import CustomFlaxBartForConditionalGeneration

Expand Down

0 comments on commit d209547

Please sign in to comment.