Improving generation speed #23

Closed
JMGaljaard opened this issue Jun 9, 2023 · 3 comments
Labels: enhancement, good first issue

Comments


JMGaljaard commented Jun 9, 2023

Dear authors,

Let me start by thanking you for the open-source release of GReaT. I found an implementation detail that slows down sample generation, especially on larger datasets.

Problem Description

Looking at GPU utilization (using nvtop), I found that the CPU workload (everything outside of sampling from the model) takes increasingly longer, so GPU utilization drops as the number of sampling iterations grows.

Proposed Solution

Digging into the code, I found that the accumulator (df_gen) and the newly generated (pd.DataFrame(td)) data frames are concatenated in every iteration:

https://github.com/kathrinse/be_great/blob/c568617763ba954fb39fc6b6e222e3abaef0886a/be_great/great.py#LL147C21-L147C21

df_gen = pd.concat([df_gen, pd.DataFrame(td)], ignore_index=True, axis=0)

This incurs O(N^2) overhead: each concatenation allocates a new DataFrame large enough to hold all rows generated so far and copies them into it. This can be resolved by collecting the generated data frames in a list and concatenating them once at the end of the generation process.

For GReaT.sample this would require only a minor change, similar to the following:

# Create an accumulation list for generated data
dfs = []
...
while n > already_generated:
    ...
    # Returns only the rows parsed in this iteration, not the full history
    df_gen = _convert_text_to_tabular_data(text_data, df_gen)
    ...
    dfs.append(df_gen)
    already_generated += len(df_gen)
    pbar.update(len(df_gen))

# Concatenate once at the end and renumber the rows 0..N-1
df_gen = pd.concat(dfs, ignore_index=True)
...
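To compare the two accumulation patterns side by side, here is a minimal, self-contained sketch. `make_batch`, `accumulate_quadratic`, and `accumulate_linear` are hypothetical names, and `make_batch` only stands in for one sampling iteration (it does not call the model); the sketch checks that both patterns yield the same rows, not how fast they are.

```python
import pandas as pd


def make_batch(batch_id: int, batch_size: int = 4) -> pd.DataFrame:
    """Hypothetical stand-in for the rows produced by one sampling iteration."""
    return pd.DataFrame(
        {"batch": [batch_id] * batch_size, "value": list(range(batch_size))}
    )


def accumulate_quadratic(n_batches: int) -> pd.DataFrame:
    # Original pattern (requires n_batches >= 1): the accumulator is copied
    # into a new DataFrame on every iteration, so copying work grows
    # quadratically with the total number of rows.
    df_gen = make_batch(0)
    for i in range(1, n_batches):
        df_gen = pd.concat([df_gen, make_batch(i)], ignore_index=True, axis=0)
    return df_gen


def accumulate_linear(n_batches: int) -> pd.DataFrame:
    # Proposed pattern (requires n_batches >= 1): keep each batch in a list
    # and concatenate once at the end, so every row is copied only once.
    dfs = []
    for i in range(n_batches):
        dfs.append(make_batch(i))
    return pd.concat(dfs, ignore_index=True)
```

Timing both functions (for example with `timeit`) for increasing `n_batches` should show the repeated-concat version falling further behind, although the exact numbers depend on the pandas version and hardware.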

The _convert_text_to_tabular_data function can be improved in the same way, by collecting the parsed rows in a list of dictionaries and constructing a single DataFrame from that list:

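def _convert_text_to_tabular_data(text: tp.List[str], df_gen: pd.DataFrame) -> pd.DataFrame:
    ...
    # Collect one dict per parsed row instead of concatenating per row
    generated = []
    ...
    for t in text:
        ...
        generated.append(td)
    # Build the DataFrame once from all collected rows
    gen_df = pd.DataFrame(generated)
    return gen_df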
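The list-of-dictionaries idea can be sketched as a standalone function. `text_to_frame` is a hypothetical name, and the `"column is value, ..."` sentence format is an assumption modeled on GReaT's textual encoding; this is an illustration, not the library's actual parser.

```python
import typing as tp

import pandas as pd


def text_to_frame(text: tp.List[str], columns: tp.List[str]) -> pd.DataFrame:
    """Hypothetical parser: one dict per sentence, one DataFrame at the end.

    Assumes sentences such as "age is 39, job is clerk"; unknown columns,
    repeated columns, and malformed features are skipped.
    """
    generated = []
    for t in text:
        # Every row starts with all known columns set to None (missing).
        td: tp.Dict[str, tp.Optional[str]] = dict.fromkeys(columns)
        for feature in t.split(","):
            parts = feature.strip().split(" is ", 1)
            if len(parts) == 2 and parts[0] in columns and td[parts[0]] is None:
                td[parts[0]] = parts[1].strip()
        generated.append(td)
    # A single DataFrame construction instead of one pd.concat per row.
    return pd.DataFrame(generated, columns=columns)
```

Passing `columns=` keeps the column order stable even when sentences list features in a different order.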

This way, for a dataset containing 20K+ samples, generation time went from 40+ minutes to about 3 minutes. Smaller datasets also seem to benefit, but less noticeably, since the per-iteration overhead grows linearly with the number of sampling iterations already performed.
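A rough count of copied rows shows why the overhead compounds. This assumes each of k iterations appends b rows and that every concatenation copies all of its input rows into a freshly allocated frame (a simplification that ignores parsing and allocation costs); N = kb is the total number of generated rows.

```latex
\begin{aligned}
\text{repeated concat:}\quad & \sum_{i=1}^{k} i\,b = \frac{b\,k(k+1)}{2} = \frac{N(k+1)}{2} \in O\!\left(\frac{N^2}{b}\right) \\
\text{single final concat:}\quad & k\,b = N \in O(N)
\end{aligned}
```

Setting b = 1, i.e. one concatenation per generated row, recovers the O(N^2) cost described above.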

Example implementation

Looking at related work, REaLTabFormer's implementation seems to provide an example of this:

https://github.com/worldbank/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/rtf_sampler.py#L610-L674

Side note

This could likely also (slightly) improve GReaT's inference/generation performance reported in Appendix B.5 of your paper.

@unnir added the enhancement and good first issue labels on Jun 9, 2023
Collaborator

unnir commented Jun 9, 2023

Dear @JMGaljaard

Thank you for suggesting an improvement to our framework!
In particular, thank you for reporting it in such an understandable way; I appreciate it.

We will be happy to receive & accept a PR :)


Madnex commented Jul 17, 2023

Is anyone working on this? Otherwise, I would be happy to create a PR based on the suggestion :)

Collaborator

unnir commented Jul 18, 2023

@Madnex, please go ahead! We will be happy to receive your PR.
