-
Notifications
You must be signed in to change notification settings - Fork 3
/
transform.go
224 lines (192 loc) · 5.65 KB
/
transform.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
// Package transform defines functions that convert data types (paperswithcode models, search results, titles and links) into the forms used by pr12er.
package transform
import (
"context"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/codingpot/paperswithcode-go/v2/models"
"github.com/codingpot/pr12er/server/pkg/pr12er"
"github.com/rocketlaunchr/google-search"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"google.golang.org/api/customsearch/v1"
"google.golang.org/api/option"
)
// Patterns shared by the extractors below; compiled once at package load.
var (
	// arxivIDRegexp matches the first "<digits>.<digits>" run in its input,
	// e.g. "2102.03732" in "https://arxiv.org/abs/2102.03732".
	arxivIDRegexp = regexp.MustCompile(`\d+\.\d+`)

	// prIDRegexp matches a talk marker such as "pr-274" or "pr 274" and
	// captures the number. The callers in this file lowercase the title
	// before matching, since the pattern itself is lowercase-only.
	prIDRegexp = regexp.MustCompile(`pr[- ]*(\d+)`)
)
// frameworkToEnum maps a paperswithcode framework string to the pr12er
// Framework enum: "tf" and "pytorch" get their own values, the empty string
// is UNSPECIFIED, and every other value is reported as OTHERS.
func frameworkToEnum(paperFramework string) pr12er.Framework {
	framework := pr12er.Framework_FRAMEWORK_OTHERS
	switch {
	case paperFramework == "":
		framework = pr12er.Framework_FRAMEWORK_UNSPECIFIED
	case paperFramework == "tf":
		framework = pr12er.Framework_FRAMEWORK_TENSORFLOW
	case paperFramework == "pytorch":
		framework = pr12er.Framework_FRAMEWORK_PYTORCH
	}
	return framework
}
// Repositories converts paperswithcode Repository models into pr12er
// Repository messages.
//
// The output has the same length and order as the input; each entry's Owner
// is derived from its URL and its Framework string is mapped to the enum.
func Repositories(repositories []models.Repository) []*pr12er.Repository {
	converted := make([]*pr12er.Repository, 0, len(repositories))
	for i := range repositories {
		src := &repositories[i]
		converted = append(converted, &pr12er.Repository{
			IsOfficial:    src.IsOfficial,
			Url:           src.URL,
			Owner:         ownerFromURL(src.URL),
			Framework:     frameworkToEnum(src.Framework),
			NumberOfStars: int64(src.Stars),
			Description:   src.Description,
		})
	}
	return converted
}
// ownerFromURL returns the second-to-last "/"-separated segment of url,
// e.g. "owner" for "https://github.com/owner/repo".
//
// Trailing slashes are ignored, so "https://github.com/owner/repo/" also
// yields "owner". It returns "" when url has fewer than two segments.
//
// TODO: it may need more sophisticated logic.
// It assumes the input URL is well formed (github.com/<owner>/<repo name>).
func ownerFromURL(url string) string {
	// Without trimming, a trailing "/" produces an empty last segment and the
	// repository name would be returned instead of the owner.
	splits := strings.Split(strings.TrimRight(url, "/"), "/")
	if len(splits) < 2 {
		return ""
	}
	return splits[len(splits)-2]
}
// Methods converts paperswithcode Method models into pr12er Method messages.
//
// The output has the same length and order as the input; Name, FullName and
// Description are copied through unchanged.
func Methods(methods []models.Method) []*pr12er.Method {
	converted := make([]*pr12er.Method, 0, len(methods))
	for i := range methods {
		src := &methods[i]
		converted = append(converted, &pr12er.Method{
			Name:        src.Name,
			FullName:    src.FullName,
			Description: src.Description,
		})
	}
	return converted
}
// ExtractPaperIDs Google Search with the title and gets ArxivIDs.
//
// The PR marker (e.g. "PR-274") is stripped from the lowercased title, the
// query is restricted to arxiv.org, and at most the first three results are
// inspected. Results whose URL is not an arxiv.org link, or contains no ID,
// are skipped; duplicate IDs are dropped while preserving result order.
//
// For example,
// "PR-274: On mutual information maximization for representation learning"
// => []string{"1907.13625", "2103.04537", "1910.08350"}.
func ExtractPaperIDs(title string) ([]string, error) {
	searchTerm := prIDRegexp.ReplaceAllString(strings.ToLower(title), "") + " site:arxiv.org"
	search, err := googlesearch.Search(context.Background(), searchTerm)
	if err != nil {
		return nil, err
	}

	maxLen := 3
	if len(search) < maxLen {
		maxLen = len(search)
	}

	var paperIDs []string
	// save paperIDs that were already seen.
	seenPaperID := map[string]bool{}
	for _, result := range search[:maxLen] {
		arxivID, err := ExtractArxivIDFromURL(result.URL)
		if err != nil {
			// Previously ignored, which let "" leak into the returned IDs.
			log.WithError(err).Warn("failed to parse arxiv ID")
			continue
		}
		// An arxiv.org URL without a recognizable ID also yields "".
		if arxivID == "" || seenPaperID[arxivID] {
			continue
		}
		seenPaperID[arxivID] = true
		paperIDs = append(paperIDs, arxivID)
	}
	return paperIDs, nil
}
// ExtractPaperIDsViaProgrammableSearch returns ArxivIDs from Programmable Search API.
//
// The PR marker (e.g. "PR-274") is removed from the lowercased title and the
// rest is sent as the query to the site-restricted Custom Search endpoint
// identified by cx, authenticated with apiKey. Result links that fail to parse
// as arxiv URLs are logged and skipped, links without an ID are ignored, and
// duplicate IDs are dropped while preserving result order.
//
// limiter throttles the API call; a nil limiter disables throttling.
func ExtractPaperIDsViaProgrammableSearch(title, cx, apiKey string, limiter *rate.Limiter) ([]string, error) {
	ctx := context.Background()
	searchTerm := prIDRegexp.ReplaceAllString(strings.ToLower(title), "")

	svc, err := customsearch.NewService(ctx, option.WithAPIKey(apiKey))
	if err != nil {
		return nil, err
	}

	// (*rate.Limiter).Wait dereferences its receiver, so a nil limiter would
	// panic; treat it as "no rate limiting" instead.
	if limiter != nil {
		if err := limiter.Wait(ctx); err != nil {
			return nil, err
		}
	}

	resp, err := svc.Cse.Siterestrict.List().Cx(cx).Q(searchTerm).Do()
	if err != nil {
		return nil, err
	}

	var arxivIDs []string
	seenArxivIDs := map[string]bool{}
	for _, item := range resp.Items {
		arxivID, err := ExtractArxivIDFromURL(item.Link)
		if err != nil {
			log.WithError(err).Warn("failed to parse arxiv ID")
			continue
		}
		if arxivID == "" || seenArxivIDs[arxivID] {
			continue
		}
		seenArxivIDs[arxivID] = true
		arxivIDs = append(arxivIDs, arxivID)
	}
	return arxivIDs, nil
}
// ExtractArxivIDFromURL extracts ArxivID from the URL.
//
// For example,
//
//	https://arxiv.org/abs/2102.03732 => 2102.03732
//
// A url that does not contain "arxiv.org" is rejected with an error. A url
// that does contain it but has no "<digits>.<digits>" run yields an empty ID
// and a nil error, so callers should also check for "".
func ExtractArxivIDFromURL(url string) (string, error) {
	if strings.Contains(url, "arxiv.org") {
		return arxivIDRegexp.FindString(url), nil
	}
	return "", fmt.Errorf("%s does not contain arxiv.org", url)
}
// ExtractPRID extracts PR ID from the title.
//
// For example,
//
//	PR-274: On mutual information maximization for representation learning
//	=> 274.
//
// Matching is case-insensitive. It returns an error when the title has no
// PR marker or when the number does not fit in an int32.
func ExtractPRID(title string) (int32, error) {
	submatch := prIDRegexp.FindStringSubmatch(strings.ToLower(title))
	if len(submatch) < 2 {
		return 0, fmt.Errorf("expected PR-XXX but got %s", title)
	}
	// bitSize 32 makes ParseInt return a range error for out-of-range values
	// instead of the silent truncation that int32(strconv.Atoi(...)) performed.
	prID, err := strconv.ParseInt(submatch[1], 10, 32)
	if err != nil {
		return 0, err
	}
	//nolint:gosec // prID is bounded to the int32 range by ParseInt above.
	return int32(prID), nil
}
// InvalidYouTubeLinkError is returned by ExtractYouTubeID when a link is not
// recognized as a YouTube link or no video ID can be found in it.
type InvalidYouTubeLinkError struct {
	url string
}

// Error implements the error interface.
func (e InvalidYouTubeLinkError) Error() string {
	return fmt.Sprintf("invalid YouTubeID is found in %s", e.url)
}

// ExtractYouTubeID extracts videoID from YouTube link
//
// For example,
//
//	https://www.youtube.com/watch?v=rtuJqQDWmIA => rtuJqQDWmIA
//	https://youtube.com/watch?v=rtuJqQDWmIA => rtuJqQDWmIA
//	https://youtu.be/rtuJqQDWmIA => rtuJqQDWmIA
//
// For hosts containing "youtube" the ID is the "v" query parameter. For hosts
// containing "youtu.be" it is the first path segment, so a trailing slash or
// extra path segments are ignored. It returns the url.Parse error for an
// unparsable link and an InvalidYouTubeLinkError when no ID is found.
func ExtractYouTubeID(link string) (string, error) {
	parsed, err := url.Parse(link)
	if err != nil {
		return "", err
	}

	var youtubeID string
	host := parsed.Hostname()
	switch {
	case strings.Contains(host, "youtube"):
		youtubeID = parsed.Query().Get("v")
	case strings.Contains(host, "youtu.be"):
		youtubeID = strings.TrimPrefix(parsed.Path, "/")
		// Keep only the first segment; "/ID/" or "/ID/extra" must not leak
		// slashes into the returned ID.
		if i := strings.IndexByte(youtubeID, '/'); i >= 0 {
			youtubeID = youtubeID[:i]
		}
	}

	if youtubeID == "" {
		return "", InvalidYouTubeLinkError{link}
	}
	return youtubeID, nil
}